harvesthealth committed on
Commit
076c26d
·
verified ·
1 Parent(s): ff41489

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/icon.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/overall_process_crop.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/webvoyager_overall_res.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,21 @@
1
  ---
2
- title: Webvoyager
3
- emoji: 🌖
4
  colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.2.0
8
  app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: WebVoyager
3
+ emoji: 🚀
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.12.0
8
  app_file: app.py
 
9
  ---
10
 
11
+ # WebVoyager
12
+
13
+ This is a Gradio demo of the WebVoyager agent, which can complete user instructions end-to-end by interacting with real-world websites.
14
+
15
+ ## How to use
16
+
17
+ 1. Enter the URL of the website you want the agent to interact with.
18
+ 2. Describe the task you want the agent to perform.
19
+ 3. Click "Submit" and watch the agent work in real time.
20
+
21
+ **Note:** This is a research demo and may not always succeed. The agent's performance depends on the complexity of the task and the website's structure.
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import threading
7
+ from run import webvoyager_run
8
+
9
def run_script_for_gradio(url, task):
    """Run one WebVoyager task for the Gradio UI and stream its log.

    This is a generator: while the agent is running it yields the current
    contents of the task's ``agent.log`` whenever they change, and it
    finishes by yielding the final log. If the agent raises, the last value
    ends with ``An error occurred: <exception>``.

    Args:
        url: Start page the agent should open.
        task: Natural-language description of what the agent should do.

    Yields:
        str: The log text accumulated so far.
    """
    # Imported locally because the module-level imports do not provide it.
    import json

    # Everything produced by this run lives in a throw-away directory.
    with tempfile.TemporaryDirectory() as temp_dir:
        task_item = {"id": "custom_task", "web": url, "ques": task}

        # json.dumps escapes quotes/backslashes/newlines in the user input,
        # which hand-built JSON text does not.
        task_file_path = os.path.join(temp_dir, 'task.jsonl')
        with open(task_file_path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(task_item) + '\n')

        # Setup arguments
        args = argparse.Namespace(
            test_file=task_file_path,
            max_iter=5,
            api_key=os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
            api_model="gpt-4-vision-preview",
            output_dir=os.path.join(temp_dir, 'results'),
            seed=None,
            max_attached_imgs=1,
            temperature=1.0,
            download_dir=os.path.join(temp_dir, 'downloads'),
            text_only=False,
            headless=True,
            save_accessibility_tree=False,
            force_device_scale=False,
            window_width=1024,
            window_height=768,
            fix_box_color=False
        )

        # Create necessary directories
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(args.download_dir, exist_ok=True)
        task_dir = os.path.join(args.output_dir, f'task{task_item["id"]}')
        os.makedirs(task_dir, exist_ok=True)
        log_path = os.path.join(task_dir, 'agent.log')

        def read_log():
            # The log file does not exist until the agent sets up logging.
            try:
                with open(log_path, 'r', encoding='utf-8', errors='replace') as log_file:
                    return log_file.read()
            except FileNotFoundError:
                return ""

        errors = []

        def run_agent():
            # Record the failure so the generator can report it.
            try:
                webvoyager_run(args, task_item, task_dir)
            except Exception as e:
                errors.append(e)

        # Run the agent in the background so the log can be polled and
        # streamed while it works.
        # NOTE(review): if the consumer closes this generator early, the
        # temporary directory is removed while the agent may still be running.
        worker = threading.Thread(target=run_agent, daemon=True)
        worker.start()

        last_sent = None
        while worker.is_alive():
            # join() with a timeout doubles as the polling delay.
            worker.join(timeout=1.0)
            snapshot = read_log()
            if snapshot != last_sent:
                last_sent = snapshot
                yield snapshot

        output = read_log()
        if errors:
            message = f"An error occurred: {errors[0]}"
            output = f"{output}\n{message}" if output else message
        yield output
84
+
85
# Gradio UI: two free-text inputs (URL and task) are passed to
# run_script_for_gradio; the strings it yields are shown in a read-only,
# 20-line textbox.
iface = gr.Interface(
    fn=run_script_for_gradio,
    inputs=[
        gr.Textbox(label="URL", placeholder="Enter the URL of the website"),
        gr.Textbox(label="Task", placeholder="Describe the task to perform"),
    ],
    outputs=gr.Textbox(label="Agent Output", lines=20, interactive=False),
    title="WebVoyager",
    description="An LMM-powered web agent that can complete user instructions end-to-end.",
)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()
assets/icon.png ADDED

Git LFS Details

  • SHA256: 98f54a18ccd946b8df0bfb04f0e91a19133852de54398272745887b024b4f9b6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
assets/overall_process_crop.png ADDED

Git LFS Details

  • SHA256: 85e7b5549b14270919bf06bf516243566e34ab466e4de42c676e9fcc50ae9eab
  • Pointer size: 132 Bytes
  • Size of remote file: 2.14 MB
assets/webvoyager_overall_res.png ADDED

Git LFS Details

  • SHA256: 38d2fc248971f80b2c76a5350c5e5d8425fe0162f06b9f737970b2dc2d919ae2
  • Pointer size: 131 Bytes
  • Size of remote file: 835 kB
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ chromium-driver
prompts.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# System prompt for screenshot mode: the observation is a screenshot with
# numerical labels. The text is sent to the model verbatim, so any edit here
# changes agent behavior.
SYSTEM_PROMPT = """Imagine you are a robot browsing the web, just like humans. Now you need to complete a task. In each iteration, you will receive an Observation that includes a screenshot of a webpage and some texts. This screenshot will feature Numerical Labels placed in the TOP LEFT corner of each Web Element.
Carefully analyze the visual information to identify the Numerical Label corresponding to the Web Element that requires interaction, then follow the guidelines and choose one of the following actions:
1. Click a Web Element.
2. Delete existing content in a textbox and then type content.
3. Scroll up or down. Multiple scrolls are allowed to browse the webpage. Pay attention!! The default scroll is the whole window. If the scroll widget is located in a certain area of the webpage, then you have to specify a Web Element in that area. I would hover the mouse there and then scroll.
4. Wait. Typically used to wait for unfinished webpage processes, with a duration of 5 seconds.
5. Go back, returning to the previous webpage.
6. Google, directly jump to the Google search page. When you can't find information in some websites, try starting over with Google.
7. Answer. This action should only be chosen when all questions in the task have been solved.

Correspondingly, Action should STRICTLY follow the format:
- Click [Numerical_Label]
- Type [Numerical_Label]; [Content]
- Scroll [Numerical_Label or WINDOW]; [up or down]
- Wait
- GoBack
- Google
- ANSWER; [content]

Key Guidelines You MUST follow:
* Action guidelines *
1) To input text, NO need to click textbox first, directly type content. After typing, the system automatically hits `ENTER` key. Sometimes you should click the search button to apply search filters. Try to use simple language when searching.
2) You must Distinguish between textbox and search button, don't type content into the button! If no textbox is found, you may need to click the search button first before the textbox is displayed.
3) Execute only one action per iteration.
4) STRICTLY Avoid repeating the same action if the webpage remains unchanged. You may have selected the wrong web element or numerical label. Continuous use of the Wait is also NOT allowed.
5) When a complex Task involves multiple questions or steps, select "ANSWER" only at the very end, after addressing all of these questions (steps). Flexibly combine your own abilities with the information in the web page. Double check the formatting requirements in the task when ANSWER.
* Web Browsing Guidelines *
1) Don't interact with useless web elements like Login, Sign-in, donation that appear in Webpages. Pay attention to Key Web Elements like search textbox and menu.
2) Vsit video websites like YouTube is allowed BUT you can't play videos. Clicking to download PDF is allowed and will be analyzed by the Assistant API.
3) Focus on the numerical labels in the TOP LEFT corner of each rectangle (element). Ensure you don't mix them up with other numbers (e.g. Calendar) on the page.
4) Focus on the date in task, you must look for results that match the date. It may be necessary to find the correct year, month and day at calendar.
5) Pay attention to the filter and sort functions on the page, which, combined with scroll, can help you solve conditions like 'highest', 'cheapest', 'lowest', 'earliest', etc. Try your best to find the answer that best fits the task.

Your reply should strictly follow the format:
Thought: {Your brief thoughts (briefly summarize the info that will help ANSWER)}
Action: {One Action format you choose}

Then the User will provide:
Observation: {A labeled screenshot Given by User}"""
40
+
41
+
42
# System prompt for text-only mode: the observation is an accessibility tree
# instead of a screenshot. The text is sent to the model verbatim, so any
# edit here changes agent behavior.
SYSTEM_PROMPT_TEXT_ONLY = """Imagine you are a robot browsing the web, just like humans. Now you need to complete a task. In each iteration, you will receive an Accessibility Tree with numerical label representing information about the page, then follow the guidelines and choose one of the following actions:
1. Click a Web Element.
2. Delete existing content in a textbox and then type content.
3. Scroll up or down. Multiple scrolls are allowed to browse the webpage. Pay attention!! The default scroll is the whole window. If the scroll widget is located in a certain area of the webpage, then you have to specify a Web Element in that area. I would hover the mouse there and then scroll.
4. Wait. Typically used to wait for unfinished webpage processes, with a duration of 5 seconds.
5. Go back, returning to the previous webpage.
6. Google, directly jump to the Google search page. When you can't find information in some websites, try starting over with Google.
7. Answer. This action should only be chosen when all questions in the task have been solved.

Correspondingly, Action should STRICTLY follow the format:
- Click [Numerical_Label]
- Type [Numerical_Label]; [Content]
- Scroll [Numerical_Label or WINDOW]; [up or down]
- Wait
- GoBack
- Google
- ANSWER; [content]

Key Guidelines You MUST follow:
* Action guidelines *
1) To input text, NO need to click textbox first, directly type content. After typing, the system automatically hits `ENTER` key. Sometimes you should click the search button to apply search filters. Try to use simple language when searching.
2) You must Distinguish between textbox and search button, don't type content into the button! If no textbox is found, you may need to click the search button first before the textbox is displayed.
3) Execute only one action per iteration.
4) STRICTLY Avoid repeating the same action if the webpage remains unchanged. You may have selected the wrong web element or numerical label. Continuous use of the Wait is also NOT allowed.
5) When a complex Task involves multiple questions or steps, select "ANSWER" only at the very end, after addressing all of these questions (steps). Flexibly combine your own abilities with the information in the web page. Double check the formatting requirements in the task when ANSWER.
* Web Browsing Guidelines *
1) Don't interact with useless web elements like Login, Sign-in, donation that appear in Webpages. Pay attention to Key Web Elements like search textbox and menu.
2) Vsit video websites like YouTube is allowed BUT you can't play videos. Clicking to download PDF is allowed and will be analyzed by the Assistant API.
3) Focus on the date in task, you must look for results that match the date. It may be necessary to find the correct year, month and day at calendar.
4) Pay attention to the filter and sort functions on the page, which, combined with scroll, can help you solve conditions like 'highest', 'cheapest', 'lowest', 'earliest', etc. Try your best to find the answer that best fits the task.

Your reply should strictly follow the format:
Thought: {Your brief thoughts (briefly summarize the info that will help ANSWER)}
Action: {One Action format you choose}

Then the User will provide:
Observation: {Accessibility Tree of a web page}"""
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ openai==1.1.1
2
+ selenium==4.15.2
3
+ pillow==10.1.0
4
+ gradio
run.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ import argparse
3
+ import time
4
+ import json
5
+ import re
6
+ import os
7
+ import shutil
8
+ import logging
9
+
10
+ from selenium import webdriver
11
+ from selenium.webdriver.common.by import By
12
+ from selenium.webdriver.common.keys import Keys
13
+ from selenium.webdriver.common.action_chains import ActionChains
14
+
15
+ from prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_TEXT_ONLY
16
+ from openai import OpenAI
17
+ from utils import get_web_element_rect, encode_image, extract_information, print_message,\
18
+ get_webarena_accessibility_tree, get_pdf_retrieval_ans_from_assistant, clip_message_and_obs, clip_message_and_obs_text_only
19
+
20
+
21
def setup_logger(folder_path):
    """Point the root logger at ``<folder_path>/agent.log`` and the console.

    Every handler currently attached to the root logger is removed and
    closed first, so each call starts from a clean logging setup.

    Args:
        folder_path: Existing directory in which ``agent.log`` is written.
    """
    log_file_path = os.path.join(folder_path, 'agent.log')

    logger = logging.getLogger()
    # Iterate over a copy: removeHandler mutates logger.handlers.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter('%(levelname)s - %(message)s')

    # Capture all logs to a file. UTF-8 is set explicitly so the file does
    # not depend on the platform's default encoding.
    file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)

    # Also echo logs to the console. StreamHandler() with no argument writes
    # to sys.stderr, not sys.stdout.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
40
+
41
+
42
def driver_config(args):
    """Build the Chrome options for one agent run.

    Note that ``args.force_device_scale`` is set to True on the caller's
    ``args`` object whenever ``args.save_accessibility_tree`` is set.

    Args:
        args: Namespace providing ``save_accessibility_tree``,
            ``force_device_scale``, ``headless`` and ``download_dir``.

    Returns:
        The configured ``webdriver.ChromeOptions`` instance.
    """
    if args.save_accessibility_tree:
        args.force_device_scale = True

    # Collect the command-line switches first, then apply them in order.
    chrome_flags = []
    if args.force_device_scale:
        chrome_flags.append("--force-device-scale-factor=1")
    if args.headless:
        chrome_flags.extend([
            "--headless",
            "--no-sandbox",  # Required for Hugging Face Spaces
            "--disable-dev-shm-usage",  # Required for Hugging Face Spaces
            "--user-agent=Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36",
        ])

    download_prefs = {
        "download.default_directory": args.download_dir,
        "plugins.always_open_pdf_externally": True
    }

    options = webdriver.ChromeOptions()
    for flag in chrome_flags:
        options.add_argument(flag)
    options.add_experimental_option("prefs", download_prefs)
    return options
64
+
65
+
66
def format_msg(it, init_msg, pdf_obs, warn_obs, web_img_b64, web_text):
    """Build the user message (text + screenshot) for one iteration.

    Args:
        it: 1-based iteration number; iteration 1 uses ``init_msg``.
        init_msg: Task introduction used on the first iteration.
        pdf_obs: PDF-related observation; when truthy it takes the place of
            ``warn_obs`` on later iterations.
        warn_obs: Warning text inserted right after ``Observation:``.
        web_img_b64: Base64-encoded PNG screenshot.
        web_text: Textual description of the labelled web elements.

    Returns:
        dict: A chat message with role ``user`` whose content is a text part
        followed by an ``image_url`` part.
    """
    # Every variant ends with the same hint followed by the element text.
    element_hint = f"I've provided the tag name of each element and the text it contains (if text exists). Note that <textarea> or <input> may be textbox, but not exactly. Please focus more on the screenshot and then refer to the textual information.\n{web_text}"

    if it == 1:
        text = init_msg + element_hint
    elif pdf_obs:
        text = f"Observation: {pdf_obs} Please analyze the response given by Assistant, then consider whether to continue iterating or not. The screenshot of the current page is also attached, give the Thought and Action. " + element_hint
    else:
        text = f"Observation:{warn_obs} please analyze the attached screenshot and give the Thought and Action. " + element_hint

    text_part = {'type': 'text', 'text': text}
    image_part = {
        'type': 'image_url',
        'image_url': {"url": f"data:image/png;base64,{web_img_b64}"}
    }
    return {'role': 'user', 'content': [text_part, image_part]}
102
+
103
+
104
def format_msg_text_only(it, init_msg, pdf_obs, warn_obs, ac_tree):
    """Build the plain-text user message for one text-only iteration.

    Args:
        it: 1-based iteration number; iteration 1 uses ``init_msg``.
        init_msg: Task introduction used on the first iteration.
        pdf_obs: PDF-related observation; when truthy it takes the place of
            ``warn_obs`` on later iterations.
        warn_obs: Warning text inserted right after ``Observation:``.
        ac_tree: Accessibility tree text of the current page.

    Returns:
        dict: A chat message with role ``user`` and string content.
    """
    if it == 1:
        body = init_msg + '\n' + ac_tree
    elif pdf_obs:
        body = f"Observation: {pdf_obs} Please analyze the response given by Assistant, then consider whether to continue iterating or not. The accessibility tree of the current page is also given, give the Thought and Action.\n{ac_tree}"
    else:
        body = f"Observation:{warn_obs} please analyze the accessibility tree and give the Thought and Action.\n{ac_tree}"

    return {'role': 'user', 'content': body}
123
+
124
+
125
def call_gpt4v_api(args, openai_client, messages):
    """Call the chat-completions API, retrying on selected error types.

    Errors whose class name is ``RateLimitError`` or ``APIError`` are retried
    after a pause (10 s and 15 s respectively), up to 10 attempts. Any other
    exception ends the call immediately.

    Args:
        args: Namespace providing ``text_only``, ``api_model`` and ``seed``.
        openai_client: Client exposing ``chat.completions.create``.
        messages: Chat messages to send.

    Returns:
        tuple: ``(prompt_tokens, completion_tokens, gpt_call_error,
        openai_response)``. On failure this is ``(None, None, True, None)``.
    """
    # Pause, in seconds, for the error class names that are retried.
    pause_by_error_name = {'RateLimitError': 10, 'APIError': 15}
    attempts = 0

    while True:
        try:
            if args.text_only:
                logging.info('Calling gpt4 API...')
                extra_kwargs = {'timeout': 30}
            else:
                logging.info('Calling gpt4v API...')
                extra_kwargs = {}
            openai_response = openai_client.chat.completions.create(
                model=args.api_model, messages=messages, max_tokens=1000, seed=args.seed,
                **extra_kwargs
            )

            usage = openai_response.usage
            prompt_tokens = usage.prompt_tokens
            completion_tokens = usage.completion_tokens

            logging.info(f'Prompt Tokens: {prompt_tokens}; Completion Tokens: {completion_tokens}')

            return prompt_tokens, completion_tokens, False, openai_response

        except Exception as e:
            error_name = type(e).__name__
            logging.info(f'Error occurred, retrying. Error type: {error_name}')

            pause = pause_by_error_name.get(error_name)
            if pause is None:
                # Not a retryable error type: report failure to the caller.
                return None, None, True, None
            time.sleep(pause)

        attempts += 1
        if attempts == 10:
            logging.info('Retrying too many times')
            return None, None, True, None
169
+
170
+
171
def exec_action_click(info, web_ele, driver_task):
    """Click ``web_ele`` and then wait three seconds.

    ``info`` is accepted but not used by this function.
    """
    # Set target="_self" on the element before clicking; presumably so that
    # links open in the current tab rather than a new one.
    driver_task.execute_script("arguments[0].setAttribute('target', '_self')", web_ele)
    web_ele.click()
    time.sleep(3)
175
+
176
+
177
def exec_action_type(info, web_ele, driver_task):
    """Clear ``web_ele``, type ``info['content']`` into it and press ENTER.

    Clearing the element and installing the key handler are best-effort:
    failures there are logged at DEBUG level and typing proceeds anyway.

    Args:
        info: Parsed action; ``info['content']`` is the text to type.
        web_ele: Target web element.
        driver_task: Selenium driver that owns ``web_ele``.

    Returns:
        str: A warning when the element does not look like a textbox,
        otherwise an empty string.
    """
    warn_obs = ""
    type_content = info['content']

    ele_tag_name = web_ele.tag_name.lower()
    ele_type = web_ele.get_attribute("type")
    # outer_html = web_ele.get_attribute("outerHTML")
    if (ele_tag_name != 'input' and ele_tag_name != 'textarea') or (ele_tag_name == 'input' and ele_type not in ['text', 'search', 'password', 'email', 'tel']):
        warn_obs = f"note: The web element you're trying to type may not be a textbox, and its tag name is <{web_ele.tag_name}>, type is {ele_type}."
    try:
        # Not always work to delete
        web_ele.clear()
        # Another way to delete
        if platform.system() == 'Darwin':
            web_ele.send_keys(Keys.COMMAND + "a")
        else:
            web_ele.send_keys(Keys.CONTROL + "a")
        web_ele.send_keys(" ")
        web_ele.send_keys(Keys.BACKSPACE)
    except Exception as e:
        # Catch Exception, not everything, so Ctrl-C and SystemExit still
        # propagate.
        logging.debug(f'Could not clear element before typing: {e}')

    actions = ActionChains(driver_task)
    actions.click(web_ele).perform()
    actions.pause(1)

    try:
        # Stop the space key from triggering its default action when the
        # event target is not a text/textarea/search field.
        driver_task.execute_script("""window.onkeydown = function(e) {if(e.keyCode == 32 && e.target.type != 'text' && e.target.type != 'textarea' && e.target.type != 'search') {e.preventDefault();}};""")
    except Exception as e:
        logging.debug(f'Could not install the space-key handler: {e}')

    actions.send_keys(type_content)
    actions.pause(2)

    actions.send_keys(Keys.ENTER)
    actions.perform()
    time.sleep(10)
    return warn_obs
215
+
216
+
217
def exec_action_scroll(info, web_eles, driver_task, args, obs_info):
    """Scroll the window, or a specific element, up or down, then wait 3 s.

    Args:
        info: Parsed action; ``info['number']`` is ``"WINDOW"`` or an element
            label, and ``info['content']`` is ``'down'`` for scrolling down
            (any other value scrolls up).
        web_eles: Labelled web elements, indexed by integer label; used when
            ``args.text_only`` is false.
        driver_task: Selenium driver for the page.
        args: Namespace providing ``window_height`` and ``text_only``.
        obs_info: Mapping from label to a dict with ``'union_bound'``; used
            when ``args.text_only`` is true.
    """
    target_label = info['number']
    scroll_down = info['content'] == 'down'

    if target_label == "WINDOW":
        # Scroll by two thirds of the window height.
        offset = args.window_height*2//3 if scroll_down else -args.window_height*2//3
        driver_task.execute_script(f"window.scrollBy(0, {offset});")
        time.sleep(3)
        return

    if args.text_only:
        # Locate the element through the centre point of its bounding box.
        box = obs_info[target_label]['union_bound']
        center_x = box[0] + box[2] // 2
        center_y = box[1] + box[3] // 2
        web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", center_x, center_y)
    else:
        web_ele = web_eles[int(target_label)]

    actions = ActionChains(driver_task)
    driver_task.execute_script("arguments[0].focus();", web_ele)
    arrow_key = Keys.ARROW_DOWN if scroll_down else Keys.ARROW_UP
    actions.key_down(Keys.ALT).send_keys(arrow_key).key_up(Keys.ALT).perform()
    time.sleep(3)
240
+
241
+
242
def webvoyager_run(args, task, task_dir):
    """Run the observe / think / act loop for a single task.

    A Chrome driver is opened on ``task['web']``. For at most
    ``args.max_iter`` iterations the page is observed (labelled screenshot or
    accessibility tree), the model is asked for a ``Thought:`` / ``Action:``
    reply, and the chosen action is executed. The driver is always shut down,
    including when an unexpected exception escapes the loop.

    Args:
        args: Run configuration namespace (API, browser and path settings).
        task: Dict with ``id``, ``web`` (start URL) and ``ques`` (the task).
        task_dir: Existing directory receiving the log, screenshots and any
            downloaded PDF files.
    """
    # OpenAI client
    client = OpenAI(api_key=args.api_key)
    options = driver_config(args)
    setup_logger(task_dir)
    logging.info(f'########## TASK{task["id"]} ##########')

    accumulate_prompt_token = 0
    accumulate_completion_token = 0

    driver_task = webdriver.Chrome(options=options)
    try:
        driver_task.set_window_size(args.window_width, args.window_height)
        driver_task.get(task['web'])
        try:
            driver_task.find_element(By.TAG_NAME, 'body').click()
        except Exception as e:
            # Best-effort initial click; the run continues without it.
            logging.debug(f'Initial click on <body> failed: {e}')
        # Stop the space key from triggering its default action outside of
        # text inputs.
        driver_task.execute_script("""window.onkeydown = function(e) {if(e.keyCode == 32 && e.target.type != 'text' && e.target.type != 'textarea') {e.preventDefault();}};""")
        time.sleep(5)

        # Start from an empty download directory so new PDFs can be detected.
        for filename in os.listdir(args.download_dir):
            file_path = os.path.join(args.download_dir, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)

        download_files = []
        fail_obs = ""
        pdf_obs = ""
        warn_obs = ""
        pattern = r'Thought:|Action:|Observation:'
        # Overlay elements added for labelling; removed after each API call.
        rects = []

        messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
        obs_prompt = "Observation: please analyze the attached screenshot and give the Thought and Action. "
        if args.text_only:
            messages = [{'role': 'system', 'content': SYSTEM_PROMPT_TEXT_ONLY}]
            obs_prompt = "Observation: please analyze the accessibility tree and give the Thought and Action."

        init_msg = f"""Now given a task: {task['ques']} Please interact with https://www.example.com and get the answer. \n"""
        init_msg = init_msg.replace('https://www.example.com', task['web'])
        init_msg = init_msg + obs_prompt

        it = 0

        while it < args.max_iter:
            logging.info(f'Iter: {it}')
            it += 1
            if not fail_obs:
                # Observe the page.
                try:
                    if not args.text_only:
                        rects, web_eles, web_eles_text = get_web_element_rect(driver_task, fix_color=args.fix_box_color)
                    else:
                        accessibility_tree_path = os.path.join(task_dir, 'accessibility_tree{}'.format(it))
                        ac_tree, obs_info = get_webarena_accessibility_tree(driver_task, accessibility_tree_path)
                except Exception as e:
                    if not args.text_only:
                        logging.error('Driver error when adding set-of-mark.')
                    else:
                        logging.error('Driver error when obtaining accessibility tree.')
                    logging.error(e)
                    break

                img_path = os.path.join(task_dir, 'screenshot{}.png'.format(it))
                driver_task.save_screenshot(img_path)

                if (not args.text_only) and args.save_accessibility_tree:
                    accessibility_tree_path = os.path.join(task_dir, 'accessibility_tree{}'.format(it))
                    get_webarena_accessibility_tree(driver_task, accessibility_tree_path)

                b64_img = encode_image(img_path)

                if not args.text_only:
                    curr_msg = format_msg(it, init_msg, pdf_obs, warn_obs, b64_img, web_eles_text)
                else:
                    curr_msg = format_msg_text_only(it, init_msg, pdf_obs, warn_obs, ac_tree)
                messages.append(curr_msg)
            else:
                # The previous action failed: send the failure text only and
                # reuse the previous observation.
                curr_msg = {
                    'role': 'user',
                    'content': fail_obs
                }
                messages.append(curr_msg)

            if not args.text_only:
                messages = clip_message_and_obs(messages, args.max_attached_imgs)
            else:
                messages = clip_message_and_obs_text_only(messages, args.max_attached_imgs)

            prompt_tokens, completion_tokens, gpt_call_error, openai_response = call_gpt4v_api(args, client, messages)

            if gpt_call_error:
                break
            else:
                accumulate_prompt_token += prompt_tokens
                accumulate_completion_token += completion_tokens
                logging.info(f'Accumulate Prompt Tokens: {accumulate_prompt_token}; Accumulate Completion Tokens: {accumulate_completion_token}')
                logging.info('API call complete...')
            gpt_4v_res = openai_response.choices[0].message.content
            messages.append({'role': 'assistant', 'content': gpt_4v_res})

            # Remove the label overlays from the page before acting on it.
            if (not args.text_only) and rects:
                logging.info(f"Num of interactive elements: {len(rects)}")
                for rect_ele in rects:
                    driver_task.execute_script("arguments[0].remove()", rect_ele)
                rects = []

            # Validate the reply explicitly: an assert would be stripped
            # under ``python -O``.
            if 'Thought:' not in gpt_4v_res or 'Action:' not in gpt_4v_res:
                logging.error("Model reply is missing 'Thought:' and/or 'Action:'.")
                fail_obs = "Format ERROR: Both 'Thought' and 'Action' should be included in your reply."
                continue

            chosen_action = re.split(pattern, gpt_4v_res)[2].strip()
            action_key, info = extract_information(chosen_action)

            fail_obs = ""
            pdf_obs = ""
            warn_obs = ""
            # Execute the chosen action.
            try:
                window_handle_task = driver_task.current_window_handle
                driver_task.switch_to.window(window_handle_task)

                if action_key == 'click':
                    if not args.text_only:
                        click_ele_number = int(info[0])
                        web_ele = web_eles[click_ele_number]
                    else:
                        click_ele_number = info[0]
                        element_box = obs_info[click_ele_number]['union_bound']
                        element_box_center = (element_box[0] + element_box[2] // 2,
                                              element_box[1] + element_box[3] // 2)
                        web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", element_box_center[0], element_box_center[1])

                    ele_tag_name = web_ele.tag_name.lower()
                    ele_type = web_ele.get_attribute("type")

                    exec_action_click(info, web_ele, driver_task)

                    # A change in the download directory means the click
                    # started a download.
                    current_files = sorted(os.listdir(args.download_dir))
                    if current_files != download_files:
                        # Give the download time to finish.
                        time.sleep(10)
                        current_files = sorted(os.listdir(args.download_dir))
                        current_download_file = [pdf_file for pdf_file in current_files if pdf_file not in download_files and pdf_file.endswith('.pdf')]
                        if current_download_file:
                            pdf_file = current_download_file[0]
                            pdf_obs = get_pdf_retrieval_ans_from_assistant(client, os.path.join(args.download_dir, pdf_file), task['ques'])
                            shutil.copy(os.path.join(args.download_dir, pdf_file), task_dir)
                            pdf_obs = "You downloaded a PDF file, I ask the Assistant API to answer the task based on the PDF file and get the following response: " + pdf_obs
                        download_files = current_files

                    if ele_tag_name == 'button' and ele_type == 'submit':
                        time.sleep(10)

                elif action_key == 'wait':
                    time.sleep(5)

                elif action_key == 'type':
                    if not args.text_only:
                        type_ele_number = int(info['number'])
                        web_ele = web_eles[type_ele_number]
                    else:
                        type_ele_number = info['number']
                        element_box = obs_info[type_ele_number]['union_bound']
                        element_box_center = (element_box[0] + element_box[2] // 2,
                                              element_box[1] + element_box[3] // 2)
                        web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", element_box_center[0], element_box_center[1])
                    warn_obs = exec_action_type(info, web_ele, driver_task)
                    if 'wolfram' in task['web']:
                        time.sleep(5)

                elif action_key == 'scroll':
                    if not args.text_only:
                        exec_action_scroll(info, web_eles, driver_task, args, None)
                    else:
                        exec_action_scroll(info, None, driver_task, args, obs_info)

                elif action_key == 'goback':
                    driver_task.back()
                    time.sleep(2)

                elif action_key == 'google':
                    driver_task.get('https://www.google.com/')
                    time.sleep(2)

                elif action_key == 'answer':
                    logging.info(info['content'])
                    logging.info('finish!!')
                    break
                else:
                    raise NotImplementedError
                fail_obs = ""
            except Exception as e:
                logging.error('driver error info:')
                logging.error(e)
                # An intercepted click is not reported back to the model.
                if 'element click intercepted' not in str(e):
                    fail_obs = "The action you have chosen cannot be exected. Please double-check if you have selected the wrong Numerical Label or Action or Action format. Then provide the revised Thought and Action."
                else:
                    fail_obs = ""
                time.sleep(2)

        print_message(messages, task_dir)
    finally:
        # Always release the browser, even when an unexpected error escapes.
        driver_task.quit()
    logging.info(f'Total cost: {accumulate_prompt_token / 1000 * 0.01 + accumulate_completion_token / 1000 * 0.03}')
444
+
445
+
446
def main():
    """Command-line entry point: run every task listed in ``--test_file``.

    The test file is read as JSON Lines (one task object per line; blank
    lines are ignored). Results are written to a timestamped sub-directory
    of ``--output_dir``, one ``task<id>`` folder per task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_file', type=str, default='data/test.json')
    parser.add_argument('--max_iter', type=int, default=5)
    parser.add_argument("--api_key", default="key", type=str, help="YOUR_OPENAI_API_KEY")
    parser.add_argument("--api_model", default="gpt-4-vision-preview", type=str, help="api model name")
    parser.add_argument("--output_dir", type=str, default='results')
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--max_attached_imgs", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--download_dir", type=str, default="downloads")
    parser.add_argument("--text_only", action='store_true')
    # for web browser
    parser.add_argument("--headless", action='store_true', help='The window of selenium')
    parser.add_argument("--save_accessibility_tree", action='store_true')
    parser.add_argument("--force_device_scale", action='store_true')
    parser.add_argument("--window_width", type=int, default=1024)
    parser.add_argument("--window_height", type=int, default=768)
    parser.add_argument("--fix_box_color", action='store_true')

    args = parser.parse_args()

    # Save Result file
    current_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())
    result_dir = os.path.join(args.output_dir, current_time)
    os.makedirs(result_dir, exist_ok=True)

    # webvoyager_run lists this directory, so it must exist before a task runs.
    os.makedirs(args.download_dir, exist_ok=True)

    # Load tasks, skipping blank lines (e.g. a trailing newline at the end of
    # the file), which json.loads would reject.
    with open(args.test_file, 'r', encoding='utf-8') as f:
        tasks = [json.loads(line) for line in f if line.strip()]

    for task in tasks:
        task_dir = os.path.join(result_dir, 'task{}'.format(task["id"]))
        os.makedirs(task_dir, exist_ok=True)
        webvoyager_run(args, task, task_dir)
484
+
485
+
486
# Run the command-line entry point only when this file is executed directly.
if __name__ == '__main__':
    main()
    print('End of process')
utils.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import re
3
+ import os
4
+ import json
5
+ import time
6
+ import logging
7
+ import numpy as np
8
+ from PIL import Image
9
+ from utils_webarena import fetch_browser_info, fetch_page_accessibility_tree,\
10
+ parse_accessibility_tree, clean_accesibility_tree
11
+
12
+
13
def resize_image(image_path):
    """Shrink the image at ``image_path`` in place so its shorter side is 512 px.

    If the shorter side is already below 512 px the file is left untouched and
    the opened image object is returned. Otherwise the image is resized with
    LANCZOS resampling (aspect ratio kept, longer side truncated to an int),
    written back to ``image_path``, and nothing is returned.
    """
    image = Image.open(image_path)
    width, height = image.size

    # Guard clause: small images are handed back unchanged.
    if min(width, height) < 512:
        return image

    if width < height:
        target_size = (512, int(height * (512 / width)))
    else:
        target_size = (int(width * (512 / height)), 512)

    image.resize(target_size, Image.LANCZOS).save(image_path)
29
+
30
+
31
+ # base64 encoding
32
+ # Code from OpenAI Document
33
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 text string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
36
+
37
+
38
+ # interact with webpage and add rectangles on elements
39
# interact with webpage and add rectangles on elements
def get_web_element_rect(browser, fix_color=True):
    """Overlay numbered boxes on interactive page elements and describe them.

    Runs a JavaScript snippet through ``browser.execute_script`` that selects
    candidate elements (inputs, buttons, links, pointer-cursor elements, ...),
    draws a dashed outline plus a numeric label over each of their client
    rects, and returns the overlay nodes together with the selected items.

    Args:
        browser: object exposing ``execute_script(js)``; the returned items
            must provide ``item['text']`` and ``item['element']`` with
            ``tag_name`` and ``get_attribute``.
        fix_color: when true every box is black (``getFixedColor``); otherwise
            each box gets a random colour (``getRandomColor``).

    Returns:
        A tuple ``(rects, elements, format_ele_text)``: the overlay nodes
        created by the script, the list of selected elements (index == label
        number), and a tab-separated textual description of the labelled
        elements.
    """
    selected_function = "getFixedColor" if fix_color else "getRandomColor"

    # Raw string: the JS regex below contains `\s`, which is an invalid escape
    # sequence in a normal Python string literal (SyntaxWarning on 3.12+).
    js_script = r"""
        let labels = [];

        function markPage() {
            var bodyRect = document.body.getBoundingClientRect();

            var items = Array.prototype.slice.call(
                document.querySelectorAll('*')
            ).map(function(element) {
                var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
                var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

                var rects = [...element.getClientRects()].filter(bb => {
                    var center_x = bb.left + bb.width / 2;
                    var center_y = bb.top + bb.height / 2;
                    var elAtCenter = document.elementFromPoint(center_x, center_y);

                    return elAtCenter === element || element.contains(elAtCenter)
                }).map(bb => {
                    const rect = {
                        left: Math.max(0, bb.left),
                        top: Math.max(0, bb.top),
                        right: Math.min(vw, bb.right),
                        bottom: Math.min(vh, bb.bottom)
                    };
                    return {
                        ...rect,
                        width: rect.right - rect.left,
                        height: rect.bottom - rect.top
                    }
                });

                var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);

                return {
                    element: element,
                    include:
                        (element.tagName === "INPUT" || element.tagName === "TEXTAREA" || element.tagName === "SELECT") ||
                        (element.tagName === "BUTTON" || element.tagName === "A" || (element.onclick != null) || window.getComputedStyle(element).cursor == "pointer") ||
                        (element.tagName === "IFRAME" || element.tagName === "VIDEO" || element.tagName === "LI" || element.tagName === "TD" || element.tagName === "OPTION")
                    ,
                    area,
                    rects,
                    text: element.textContent.trim().replace(/\s{2,}/g, ' ')
                };
            }).filter(item =>
                item.include && (item.area >= 20)
            );

            // Only keep inner clickable items
            // first delete button inner clickable items
            const buttons = Array.from(document.querySelectorAll('button, a, input[type="button"], div[role="button"]'));

            //items = items.filter(x => !buttons.some(y => y.contains(x.element) && !(x.element === y) ));
            items = items.filter(x => !buttons.some(y => items.some(z => z.element === y) && y.contains(x.element) && !(x.element === y) ));
            items = items.filter(x =>
                !(x.element.parentNode &&
                x.element.parentNode.tagName === 'SPAN' &&
                x.element.parentNode.children.length === 1 &&
                x.element.parentNode.getAttribute('role') &&
                items.some(y => y.element === x.element.parentNode)));

            items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))

            // Function to generate random colors
            function getRandomColor(index) {
                var letters = '0123456789ABCDEF';
                var color = '#';
                for (var i = 0; i < 6; i++) {
                    color += letters[Math.floor(Math.random() * 16)];
                }
                return color;
            }

            function getFixedColor(index) {
                var color = '#000000'
                return color
            }
            //function getFixedColor(index){
            //    var colors = ['#FF0000', '#00FF00', '#0000FF', '#000000']; // Red, Green, Blue, Black
            //    return colors[index % 4];
            //}


            // Lets create a floating border on top of these elements that will always be visible
            items.forEach(function(item, index) {
                item.rects.forEach((bbox) => {
                    // Declared locally so the overlay node does not leak into the page's globals.
                    var newElement = document.createElement("div");
                    var borderColor = COLOR_FUNCTION(index);
                    newElement.style.outline = `2px dashed ${borderColor}`;
                    newElement.style.position = "fixed";
                    newElement.style.left = bbox.left + "px";
                    newElement.style.top = bbox.top + "px";
                    newElement.style.width = bbox.width + "px";
                    newElement.style.height = bbox.height + "px";
                    newElement.style.pointerEvents = "none";
                    newElement.style.boxSizing = "border-box";
                    newElement.style.zIndex = 2147483647;
                    // newElement.style.background = `${borderColor}80`;

                    // Add floating label at the corner
                    var label = document.createElement("span");
                    label.textContent = index;
                    label.style.position = "absolute";
                    //label.style.top = "-19px";
                    label.style.top = Math.max(-19, -bbox.top) + "px";
                    //label.style.left = "0px";
                    label.style.left = Math.min(Math.floor(bbox.width / 5), 2) + "px";
                    label.style.background = borderColor;
                    label.style.color = "white";
                    label.style.padding = "2px 4px";
                    label.style.fontSize = "12px";
                    label.style.borderRadius = "2px";
                    newElement.appendChild(label);

                    document.body.appendChild(newElement);
                    labels.push(newElement);
                    // item.element.setAttribute("-ai-label", label.textContent);
                });
            })

            // For the first way
            // return [labels, items.map(item => ({
            //     rect: item.rects[0] // assuming there's at least one rect
            // }))];

            // For the second way
            return [labels, items]
        }
        return markPage();""".replace("COLOR_FUNCTION", selected_function)
    rects, items_raw = browser.execute_script(js_script)

    input_attr_types = ['text', 'search', 'password', 'email', 'tel']
    format_ele_text = []
    for web_ele_id, item in enumerate(items_raw):
        label_text = item['text']
        element = item['element']
        ele_tag_name = element.tag_name
        ele_type = element.get_attribute("type")
        ele_aria_label = element.get_attribute("aria-label")

        if not label_text:
            # Elements without visible text are only listed when they are
            # text-entry fields or buttons; the aria-label stands in for the text.
            tag = ele_tag_name.lower()
            if (
                (tag == 'input' and ele_type in input_attr_types)
                or tag == 'textarea'
                or (tag == 'button' and ele_type in ['submit', 'button'])
            ):
                if ele_aria_label:
                    format_ele_text.append(f'[{web_ele_id}]: <{ele_tag_name}> "{ele_aria_label}";')
                else:
                    format_ele_text.append(f'[{web_ele_id}]: <{ele_tag_name}> "{label_text}";')

        elif len(label_text) < 200:
            # Skip labels that look like raw <img src=...> markup.
            if not ("<img" in label_text and "src=" in label_text):
                has_extra_aria = ele_aria_label and (ele_aria_label != label_text)
                if ele_tag_name in ["button", "input", "textarea"]:
                    if has_extra_aria:
                        format_ele_text.append(f'[{web_ele_id}]: <{ele_tag_name}> "{label_text}", "{ele_aria_label}";')
                    else:
                        format_ele_text.append(f'[{web_ele_id}]: <{ele_tag_name}> "{label_text}";')
                else:
                    if has_extra_aria:
                        format_ele_text.append(f'[{web_ele_id}]: "{label_text}", "{ele_aria_label}";')
                    else:
                        format_ele_text.append(f'[{web_ele_id}]: "{label_text}";')

    format_ele_text = '\t'.join(format_ele_text)
    return rects, [web_ele['element'] for web_ele in items_raw], format_ele_text
211
+
212
+
213
def extract_information(text):
    """Parse a model reply into an ``(action, info)`` pair.

    Patterns are tried in a fixed order and the first match wins. For
    ``click``/``wait``/``goback``/``google`` the info is the tuple of regex
    groups; for ``type``/``scroll`` it is ``{"number", "content"}``; for
    ``answer`` it is ``{"content"}``. Returns ``(None, None)`` when nothing
    matches.
    """
    ordered_patterns = [
        ("click", r"Click \[?(\d+)\]?"),
        ("type", r"Type \[?(\d+)\]?[; ]+\[?(.[^\]]*)\]?"),
        # ("delete_and_type", r"Delete_and_Type \[?(\d+)\]?[; ]+\[?(.[^\]]*)\]?"),
        ("scroll", r"Scroll \[?(\d+|WINDOW)\]?[; ]+\[?(up|down)\]?"),
        ("wait", r"^Wait"),
        ("goback", r"^GoBack"),
        ("google", r"^Google"),
        ("answer", r"ANSWER[; ]+\[?(.[^\]]*)\]?"),
    ]

    for action, regex in ordered_patterns:
        found = re.search(regex, text)
        if not found:
            continue
        if action in ("click", "wait", "goback", "google"):
            # no content
            return action, found.groups()
        if action in ("type", "scroll"):
            return action, {"number": found.group(1), "content": found.group(2)}
        return action, {"content": found.group(1)}
    return None, None
234
+
235
+
236
def clip_message(msg, max_img_num):
    """Return a copy of ``msg`` keeping images only in the newest user turns.

    Walking from the newest message backwards, the first ``max_img_num`` user
    messages whose content is not a plain string are kept untouched; older
    ones are replaced by ``{'role', 'content'}`` where content is the text of
    their first content part. All other messages are passed through as-is.
    """
    newest_first = []
    images_kept = 0
    for entry in reversed(msg):
        is_multipart_user = entry['role'] == 'user' and type(entry['content']) is not str
        if is_multipart_user:
            if images_kept < max_img_num:
                images_kept += 1
            else:
                entry = {
                    'role': entry['role'],
                    'content': entry['content'][0]["text"],
                }
        newest_first.append(entry)
    newest_first.reverse()
    return newest_first
256
+
257
+
258
def clip_message_and_obs(msg, max_img_num):
    """Return a copy of ``msg`` with old screenshot observations summarised.

    Walking from the newest message backwards, the first ``max_img_num`` user
    messages with non-string content are kept untouched. Older ones are
    replaced by a text-only message: everything before ``"Observation:"`` in
    the first content part, followed by a fixed placeholder (a PDF-aware
    variant is used when the text mentions a downloaded PDF file).
    """
    newest_first = []
    images_kept = 0
    for entry in reversed(msg):
        if entry['role'] == 'user' and type(entry['content']) is not str:
            if images_kept < max_img_num:
                images_kept += 1
            else:
                full_text = entry['content'][0]["text"]
                prefix = full_text.split("Observation:")[0].strip()
                if "You downloaded a PDF file" in full_text:
                    summary = prefix + "Observation: A screenshot, a PDF file and some texts. (Omitted in context.)"
                else:
                    summary = prefix + "Observation: A screenshot and some texts. (Omitted in context.)"
                entry = {'role': entry['role'], 'content': summary}
        newest_first.append(entry)
    newest_first.reverse()
    return newest_first
280
+
281
+
282
def clip_message_and_obs_text_only(msg, max_tree_num):
    """Return a copy of ``msg`` with old accessibility-tree observations summarised.

    Walking from the newest message backwards, the first ``max_tree_num`` user
    messages are kept untouched. Older user messages (whose content is a
    string) are replaced by the text before ``"Observation:"`` followed by a
    fixed placeholder (a PDF-aware variant is used when the text mentions a
    downloaded PDF file). Non-user messages are passed through as-is.
    """
    newest_first = []
    trees_kept = 0
    for entry in reversed(msg):
        if entry['role'] == 'user':
            if trees_kept < max_tree_num:
                trees_kept += 1
            else:
                full_text = entry['content']
                prefix = full_text.split("Observation:")[0].strip()
                if "You downloaded a PDF file" in full_text:
                    summary = prefix + "Observation: An accessibility tree and a PDF file. (Omitted in context.)"
                else:
                    summary = prefix + "Observation: An accessibility tree. (Omitted in context.)"
                entry = {'role': entry['role'], 'content': summary}
        newest_first.append(entry)
    newest_first.reverse()
    return newest_first
302
+
303
+
304
def print_message(json_object, save_dir=None):
    """Log a conversation with image payloads masked, optionally saving it.

    Every message is logged via ``logging.info``. For user messages with
    multi-part content, each ``image_url`` part has its URL replaced by the
    placeholder ``data:image/png;base64,{b64_img}`` in the logged/saved copy.
    The caller's messages are not modified.

    Args:
        json_object: iterable of message dicts with ``'role'`` and ``'content'``.
        save_dir: if given, the masked conversation is written to
            ``interact_messages.json`` inside this directory.
    """
    remove_b64code_obj = []
    for obj in json_object:
        if obj['role'] != 'user' or type(obj['content']) == str:
            logging.info(obj)
            remove_b64code_obj.append(obj)
            continue

        # Build fresh part dicts so masking the image does not overwrite the
        # real base64 payload in the caller's message list.
        masked_content = [
            {**item, 'image_url': {"url": "data:image/png;base64,{b64_img}"}}
            if item['type'] == 'image_url' else item
            for item in obj['content']
        ]
        print_obj = {
            'role': obj['role'],
            'content': masked_content
        }
        logging.info(print_obj)
        remove_b64code_obj.append(print_obj)

    if save_dir:
        with open(os.path.join(save_dir, 'interact_messages.json'), 'w', encoding='utf-8') as fw:
            json.dump(remove_b64code_obj, fw, indent=2)
331
+
332
+
333
def get_webarena_accessibility_tree(browser, save_file=None):
    """Fetch, parse and clean the accessibility tree of the current viewport.

    Args:
        browser: browser handle passed to the ``utils_webarena`` helpers.
        save_file: optional path prefix; when given, the node info is dumped
            to ``<save_file>.json`` and the cleaned tree text to
            ``<save_file>.txt``.

    Returns:
        ``(content, obs_nodes_info)`` - the cleaned tree text and the node
        info produced by ``parse_accessibility_tree``.
    """
    tree = fetch_page_accessibility_tree(
        fetch_browser_info(browser), browser, current_viewport_only=True
    )
    raw_content, obs_nodes_info = parse_accessibility_tree(tree)
    content = clean_accesibility_tree(raw_content)

    if not save_file:
        return content, obs_nodes_info

    with open(save_file + '.json', 'w', encoding='utf-8') as json_out:
        json.dump(obs_nodes_info, json_out, indent=2)
    with open(save_file + '.txt', 'w', encoding='utf-8') as text_out:
        text_out.write(content)
    return content, obs_nodes_info
346
+
347
+
348
def compare_images(img1_path, img2_path):
    """Return the summed absolute per-pixel difference between two image files.

    Args:
        img1_path: path of the first image.
        img2_path: path of the second image.

    Returns:
        A NumPy scalar: ``sum(|img1 - img2|)`` over all pixel values. ``0``
        means the decoded pixel arrays are identical.
    """
    with Image.open(img1_path) as img1, Image.open(img2_path) as img2:
        img1_array = np.asarray(img1)
        img2_array = np.asarray(img2)

        # Unsigned (e.g. uint8) subtraction wraps around modulo 2**bits, which
        # would make 10 - 20 come out as 246. Widen to a signed type first so
        # the absolute difference is the true distance.
        if img1_array.dtype.kind in "ub":
            img1_array = img1_array.astype(np.int64)
        if img2_array.dtype.kind in "ub":
            img2_array = img2_array.astype(np.int64)

    difference = np.abs(img1_array - img2_array)

    total_difference = np.sum(difference)

    return total_difference
360
+
361
+
362
def get_pdf_retrieval_ans_from_assistant(client, pdf_path, task):
    """Ask an OpenAI retrieval assistant to answer ``task`` from a PDF file.

    Uploads the PDF, creates a temporary assistant with the retrieval tool,
    posts ``task`` as a user message, polls the run every two seconds until it
    finishes, and returns the text of the newest message in the thread. The
    assistant (and its file attachment) is deleted afterwards, also on error.

    Args:
        client: OpenAI client exposing ``files`` and ``beta.assistants`` /
            ``beta.threads`` APIs.
        pdf_path: path of the PDF file to upload.
        task: the question/instruction sent to the assistant.

    Returns:
        The text value of the first content part of the newest thread message.

    Raises:
        RuntimeError: if the run ends in a ``failed``, ``cancelled`` or
            ``expired`` state instead of completing.
    """
    logging.info("You download a PDF file that will be retrieved using the Assistant API.")
    # Close the local file handle as soon as the upload call returns.
    with open(pdf_path, "rb") as pdf_file:
        file = client.files.create(
            file=pdf_file,
            purpose='assistants'
        )
    logging.info("Create assistant...")
    assistant = client.beta.assistants.create(
        instructions="You are a helpful assistant that can analyze the content of a PDF file and give an answer that matches the given task, or retrieve relevant content that matches the task.",
        model="gpt-4-1106-preview",
        tools=[{"type": "retrieval"}],
        file_ids=[file.id]
    )
    try:
        thread = client.beta.threads.create()
        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=task,
            file_ids=[file.id]
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant.id
        )
        while True:
            # Retrieve the run status
            run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
            if run_status.status == 'completed':
                break
            # These states never turn into 'completed'; polling on would hang forever.
            if run_status.status in ('failed', 'cancelled', 'expired'):
                raise RuntimeError(
                    "PDF retrieval run ended with status '{}'".format(run_status.status)
                )
            time.sleep(2)
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        messages_text = messages.data[0].content[0].text.value
    finally:
        # Always detach the file and remove the temporary assistant.
        file_deletion_status = client.beta.assistants.files.delete(
            assistant_id=assistant.id,
            file_id=file.id
        )
        logging.info(file_deletion_status)
        assistant_deletion_status = client.beta.assistants.delete(assistant.id)
        logging.info(assistant_deletion_status)
    return messages_text