Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- README.md +15 -6
- app.py +97 -0
- assets/icon.png +3 -0
- assets/overall_process_crop.png +3 -0
- assets/webvoyager_overall_res.png +3 -0
- packages.txt +1 -0
- prompts.py +78 -0
- requirements.txt +4 -0
- run.py +488 -0
- utils.py +405 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/icon.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/overall_process_crop.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/webvoyager_overall_res.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,21 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: WebVoyager
|
| 3 |
+
emoji: 🚀
|
| 4 |
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.12.0
|
| 8 |
app_file: app.py
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# WebVoyager
|
| 12 |
+
|
| 13 |
+
This is a Gradio demo of the WebVoyager agent, which can complete user instructions end-to-end by interacting with real-world websites.
|
| 14 |
+
|
| 15 |
+
## How to use
|
| 16 |
+
|
| 17 |
+
1. Enter the URL of the website you want the agent to interact with.
|
| 18 |
+
2. Describe the task you want the agent to perform.
|
| 19 |
+
3. Click "Submit" and watch the agent work in real-time.
|
| 20 |
+
|
| 21 |
+
**Note:** This is a research demo and may not always succeed. The agent's performance depends on the complexity of the task and the website's structure.
|
app.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import json
import os
import sys
import tempfile
import threading

import gradio as gr

from run import webvoyager_run
|
| 8 |
+
|
| 9 |
+
def run_script_for_gradio(url, task):
    """Run the WebVoyager agent for one (url, task) pair and yield its log.

    Generator used as the Gradio callback: it writes a one-task ``.jsonl``
    file into a temporary directory, invokes ``webvoyager_run`` and yields
    the agent log content so the UI can display it. On failure it yields an
    error message instead of raising, so the interface never crashes.

    Parameters:
        url: website the agent should start from.
        task: natural-language instruction for the agent.
    Yields:
        str: the agent log (or an error description).
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Serialize with json.dump so quotes/newlines in user input cannot
        # corrupt the JSONL line (the previous f-string did not escape).
        task_file_path = os.path.join(temp_dir, 'task.jsonl')
        with open(task_file_path, 'w', encoding='utf-8') as f:
            json.dump({"id": "custom_task", "web": url, "ques": task}, f)
            f.write('\n')

        # Arguments mirroring run.py's CLI defaults, pinned for the demo.
        args = argparse.Namespace(
            test_file=task_file_path,
            max_iter=5,
            api_key=os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
            api_model="gpt-4-vision-preview",
            output_dir=os.path.join(temp_dir, 'results'),
            seed=None,
            max_attached_imgs=1,
            temperature=1.0,
            download_dir=os.path.join(temp_dir, 'downloads'),
            text_only=False,
            headless=True,  # Spaces has no display server
            save_accessibility_tree=False,
            force_device_scale=False,
            window_width=1024,
            window_height=768,
            fix_box_color=False,
        )

        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(args.download_dir, exist_ok=True)

        # NOTE: the previous revision redirected sys.stdout to a
        # threading.local() object (which has no write()) and busy-polled
        # the log file from a daemon thread; both were broken (any print
        # would raise, and the polled value was never used) and are removed.
        try:
            with open(task_file_path, 'r', encoding='utf-8') as f:
                task_data = [json.loads(line) for line in f]

            for task_item in task_data:
                task_dir = os.path.join(args.output_dir, f'task{task_item["id"]}')
                os.makedirs(task_dir, exist_ok=True)
                webvoyager_run(args, task_item, task_dir)

                # Surface the final agent log to the UI.
                log_path = os.path.join(task_dir, 'agent.log')
                with open(log_path, 'r', encoding='utf-8') as f:
                    yield f.read()
        except Exception as e:
            yield f"An error occurred: {e}"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Gradio front-end: two text inputs (start URL and task description) feeding
# the streaming agent callback; the log appears in a read-only textbox.
url_box = gr.Textbox(label="URL", placeholder="Enter the URL of the website")
task_box = gr.Textbox(label="Task", placeholder="Describe the task to perform")
log_box = gr.Textbox(label="Agent Output", lines=20, interactive=False)

iface = gr.Interface(
    fn=run_script_for_gradio,
    inputs=[url_box, task_box],
    outputs=log_box,
    title="WebVoyager",
    description="An LMM-powered web agent that can complete user instructions end-to-end.",
)

if __name__ == "__main__":
    iface.launch()
|
assets/icon.png
ADDED
|
|
Git LFS Details
|
assets/overall_process_crop.png
ADDED
|
Git LFS Details
|
assets/webvoyager_overall_res.png
ADDED
|
Git LFS Details
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
chromium-driver
|
prompts.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# System prompt for the vision agent (screenshot with numbered element labels).
# Fix: "Vsit" -> "Visit" in Web Browsing Guideline 2.
SYSTEM_PROMPT = """Imagine you are a robot browsing the web, just like humans. Now you need to complete a task. In each iteration, you will receive an Observation that includes a screenshot of a webpage and some texts. This screenshot will feature Numerical Labels placed in the TOP LEFT corner of each Web Element.
Carefully analyze the visual information to identify the Numerical Label corresponding to the Web Element that requires interaction, then follow the guidelines and choose one of the following actions:
1. Click a Web Element.
2. Delete existing content in a textbox and then type content.
3. Scroll up or down. Multiple scrolls are allowed to browse the webpage. Pay attention!! The default scroll is the whole window. If the scroll widget is located in a certain area of the webpage, then you have to specify a Web Element in that area. I would hover the mouse there and then scroll.
4. Wait. Typically used to wait for unfinished webpage processes, with a duration of 5 seconds.
5. Go back, returning to the previous webpage.
6. Google, directly jump to the Google search page. When you can't find information in some websites, try starting over with Google.
7. Answer. This action should only be chosen when all questions in the task have been solved.

Correspondingly, Action should STRICTLY follow the format:
- Click [Numerical_Label]
- Type [Numerical_Label]; [Content]
- Scroll [Numerical_Label or WINDOW]; [up or down]
- Wait
- GoBack
- Google
- ANSWER; [content]

Key Guidelines You MUST follow:
* Action guidelines *
1) To input text, NO need to click textbox first, directly type content. After typing, the system automatically hits `ENTER` key. Sometimes you should click the search button to apply search filters. Try to use simple language when searching.
2) You must Distinguish between textbox and search button, don't type content into the button! If no textbox is found, you may need to click the search button first before the textbox is displayed.
3) Execute only one action per iteration.
4) STRICTLY Avoid repeating the same action if the webpage remains unchanged. You may have selected the wrong web element or numerical label. Continuous use of the Wait is also NOT allowed.
5) When a complex Task involves multiple questions or steps, select "ANSWER" only at the very end, after addressing all of these questions (steps). Flexibly combine your own abilities with the information in the web page. Double check the formatting requirements in the task when ANSWER.
* Web Browsing Guidelines *
1) Don't interact with useless web elements like Login, Sign-in, donation that appear in Webpages. Pay attention to Key Web Elements like search textbox and menu.
2) Visit video websites like YouTube is allowed BUT you can't play videos. Clicking to download PDF is allowed and will be analyzed by the Assistant API.
3) Focus on the numerical labels in the TOP LEFT corner of each rectangle (element). Ensure you don't mix them up with other numbers (e.g. Calendar) on the page.
4) Focus on the date in task, you must look for results that match the date. It may be necessary to find the correct year, month and day at calendar.
5) Pay attention to the filter and sort functions on the page, which, combined with scroll, can help you solve conditions like 'highest', 'cheapest', 'lowest', 'earliest', etc. Try your best to find the answer that best fits the task.

Your reply should strictly follow the format:
Thought: {Your brief thoughts (briefly summarize the info that will help ANSWER)}
Action: {One Action format you choose}

Then the User will provide:
Observation: {A labeled screenshot Given by User}"""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# System prompt for text-only mode (accessibility tree instead of screenshots).
# Fix: "Vsit" -> "Visit" in Web Browsing Guideline 2.
SYSTEM_PROMPT_TEXT_ONLY = """Imagine you are a robot browsing the web, just like humans. Now you need to complete a task. In each iteration, you will receive an Accessibility Tree with numerical label representing information about the page, then follow the guidelines and choose one of the following actions:
1. Click a Web Element.
2. Delete existing content in a textbox and then type content.
3. Scroll up or down. Multiple scrolls are allowed to browse the webpage. Pay attention!! The default scroll is the whole window. If the scroll widget is located in a certain area of the webpage, then you have to specify a Web Element in that area. I would hover the mouse there and then scroll.
4. Wait. Typically used to wait for unfinished webpage processes, with a duration of 5 seconds.
5. Go back, returning to the previous webpage.
6. Google, directly jump to the Google search page. When you can't find information in some websites, try starting over with Google.
7. Answer. This action should only be chosen when all questions in the task have been solved.

Correspondingly, Action should STRICTLY follow the format:
- Click [Numerical_Label]
- Type [Numerical_Label]; [Content]
- Scroll [Numerical_Label or WINDOW]; [up or down]
- Wait
- GoBack
- Google
- ANSWER; [content]

Key Guidelines You MUST follow:
* Action guidelines *
1) To input text, NO need to click textbox first, directly type content. After typing, the system automatically hits `ENTER` key. Sometimes you should click the search button to apply search filters. Try to use simple language when searching.
2) You must Distinguish between textbox and search button, don't type content into the button! If no textbox is found, you may need to click the search button first before the textbox is displayed.
3) Execute only one action per iteration.
4) STRICTLY Avoid repeating the same action if the webpage remains unchanged. You may have selected the wrong web element or numerical label. Continuous use of the Wait is also NOT allowed.
5) When a complex Task involves multiple questions or steps, select "ANSWER" only at the very end, after addressing all of these questions (steps). Flexibly combine your own abilities with the information in the web page. Double check the formatting requirements in the task when ANSWER.
* Web Browsing Guidelines *
1) Don't interact with useless web elements like Login, Sign-in, donation that appear in Webpages. Pay attention to Key Web Elements like search textbox and menu.
2) Visit video websites like YouTube is allowed BUT you can't play videos. Clicking to download PDF is allowed and will be analyzed by the Assistant API.
3) Focus on the date in task, you must look for results that match the date. It may be necessary to find the correct year, month and day at calendar.
4) Pay attention to the filter and sort functions on the page, which, combined with scroll, can help you solve conditions like 'highest', 'cheapest', 'lowest', 'earliest', etc. Try your best to find the answer that best fits the task.

Your reply should strictly follow the format:
Thought: {Your brief thoughts (briefly summarize the info that will help ANSWER)}
Action: {One Action format you choose}

Then the User will provide:
Observation: {Accessibility Tree of a web page}"""
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openai==1.1.1
|
| 2 |
+
selenium==4.15.2
|
| 3 |
+
pillow==10.1.0
|
| 4 |
+
gradio
|
run.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import platform
|
| 2 |
+
import argparse
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
from selenium import webdriver
|
| 11 |
+
from selenium.webdriver.common.by import By
|
| 12 |
+
from selenium.webdriver.common.keys import Keys
|
| 13 |
+
from selenium.webdriver.common.action_chains import ActionChains
|
| 14 |
+
|
| 15 |
+
from prompts import SYSTEM_PROMPT, SYSTEM_PROMPT_TEXT_ONLY
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
from utils import get_web_element_rect, encode_image, extract_information, print_message,\
|
| 18 |
+
get_webarena_accessibility_tree, get_pdf_retrieval_ans_from_assistant, clip_message_and_obs, clip_message_and_obs_text_only
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def setup_logger(folder_path):
    """(Re)configure the root logger to write ``agent.log`` in *folder_path*.

    All existing handlers are removed first so that repeated calls (one per
    task) do not duplicate log lines. Output goes both to the file and to
    the console so Gradio can show progress in real time.
    """
    log_path = os.path.join(folder_path, 'agent.log')

    root = logging.getLogger()
    # Drop handlers left over from a previous task before attaching new ones.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)
        old_handler.close()

    fmt = logging.Formatter('%(levelname)s - %(message)s')

    # Persist every record to the per-task log file.
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(fmt)
    root.addHandler(file_handler)
    root.setLevel(logging.INFO)

    # Mirror records to the console for live viewing in Gradio.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(fmt)
    root.addHandler(console_handler)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def driver_config(args):
    """Build Chrome options for the agent's browser session.

    Honors the headless / device-scale flags on *args* and routes downloads
    to ``args.download_dir`` (PDFs are downloaded instead of previewed).
    """
    options = webdriver.ChromeOptions()

    # Accessibility-tree dumps only line up with screenshots at scale 1.
    if args.save_accessibility_tree:
        args.force_device_scale = True
    if args.force_device_scale:
        options.add_argument("--force-device-scale-factor=1")

    if args.headless:
        options.add_argument("--headless")
        # --no-sandbox and --disable-dev-shm-usage are required to run
        # Chrome inside Hugging Face Spaces containers.
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        # Present a regular desktop user agent while headless.
        options.add_argument(
            "--user-agent=Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36"
        )

    download_prefs = {
        "download.default_directory": args.download_dir,
        "plugins.always_open_pdf_externally": True,
    }
    options.add_experimental_option("prefs", download_prefs)
    return options
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def format_msg(it, init_msg, pdf_obs, warn_obs, web_img_b64, web_text):
    """Build the per-iteration user message (text plus screenshot attachment).

    ``it == 1`` produces the initial task message; later iterations report
    the observation, with a different preamble when a PDF was just processed
    (``pdf_obs`` non-empty). Returns a chat dict in the OpenAI vision format.
    """
    screenshot_part = {
        'type': 'image_url',
        'image_url': {"url": f"data:image/png;base64,{web_img_b64}"},
    }
    if it == 1:
        text = init_msg + f"I've provided the tag name of each element and the text it contains (if text exists). Note that <textarea> or <input> may be textbox, but not exactly. Please focus more on the screenshot and then refer to the textual information.\n{web_text}"
    elif not pdf_obs:
        text = f"Observation:{warn_obs} please analyze the attached screenshot and give the Thought and Action. I've provided the tag name of each element and the text it contains (if text exists). Note that <textarea> or <input> may be textbox, but not exactly. Please focus more on the screenshot and then refer to the textual information.\n{web_text}"
    else:
        text = f"Observation: {pdf_obs} Please analyze the response given by Assistant, then consider whether to continue iterating or not. The screenshot of the current page is also attached, give the Thought and Action. I've provided the tag name of each element and the text it contains (if text exists). Note that <textarea> or <input> may be textbox, but not exactly. Please focus more on the screenshot and then refer to the textual information.\n{web_text}"
    return {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': text},
            screenshot_part,
        ],
    }
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def format_msg_text_only(it, init_msg, pdf_obs, warn_obs, ac_tree):
    """Build the per-iteration user message for text-only mode.

    Same contract as ``format_msg`` but the observation is an accessibility
    tree string, so the content is plain text with no image attachment.
    """
    if it == 1:
        content = init_msg + '\n' + ac_tree
    elif not pdf_obs:
        content = f"Observation:{warn_obs} please analyze the accessibility tree and give the Thought and Action.\n{ac_tree}"
    else:
        content = f"Observation: {pdf_obs} Please analyze the response given by Assistant, then consider whether to continue iterating or not. The accessibility tree of the current page is also given, give the Thought and Action.\n{ac_tree}"
    return {'role': 'user', 'content': content}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def call_gpt4v_api(args, openai_client, messages):
    """Call the OpenAI chat-completions API with simple retry logic.

    Returns a 4-tuple ``(prompt_tokens, completion_tokens, gpt_call_error,
    response)``; on unrecoverable errors or after 10 retries it returns
    ``(None, None, True, None)``.
    """
    retry_times = 0
    while True:
        try:
            if not args.text_only:
                logging.info('Calling gpt4v API...')
                openai_response = openai_client.chat.completions.create(
                    model=args.api_model, messages=messages, max_tokens=1000, seed=args.seed
                )
            else:
                # Text-only requests carry no images, so a short timeout is safe.
                logging.info('Calling gpt4 API...')
                openai_response = openai_client.chat.completions.create(
                    model=args.api_model, messages=messages, max_tokens=1000, seed=args.seed, timeout=30
                )

            prompt_tokens = openai_response.usage.prompt_tokens
            completion_tokens = openai_response.usage.completion_tokens

            logging.info(f'Prompt Tokens: {prompt_tokens}; Completion Tokens: {completion_tokens}')

            gpt_call_error = False
            return prompt_tokens, completion_tokens, gpt_call_error, openai_response

        except Exception as e:
            # Dispatch on the exception class NAME so this works across
            # openai SDK versions without importing their exception types.
            logging.info(f'Error occurred, retrying. Error type: {type(e).__name__}')

            if type(e).__name__ == 'RateLimitError':
                # Transient: back off and retry.
                time.sleep(10)

            elif type(e).__name__ == 'APIError':
                # Transient server-side error: longer back-off, then retry.
                time.sleep(15)

            elif type(e).__name__ == 'InvalidRequestError':
                # Malformed request will never succeed: give up immediately.
                gpt_call_error = True
                return None, None, gpt_call_error, None

            else:
                # Unknown failure mode: give up rather than loop blindly.
                gpt_call_error = True
                return None, None, gpt_call_error, None

        # Only reached after a transient error; cap total retries at 10.
        retry_times += 1
        if retry_times == 10:
            logging.info('Retrying too many times')
            return None, None, True, None
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def exec_action_click(info, web_ele, driver_task):
    """Click *web_ele*, forcing links to open in the current tab."""
    # Rewrite the anchor target so new-tab links stay in this window,
    # keeping the driver attached to the page the agent is observing.
    driver_task.execute_script(
        "arguments[0].setAttribute('target', '_self')", web_ele
    )
    web_ele.click()
    # Give the page time to react before the next screenshot.
    time.sleep(3)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def exec_action_type(info, web_ele, driver_task):
    """Clear *web_ele*, type ``info['content']`` into it, then press ENTER.

    Returns a warning string (possibly empty) when the target does not look
    like a textbox, so the model can reconsider its element choice.
    """
    warn_obs = ""
    type_content = info['content']

    ele_tag_name = web_ele.tag_name.lower()
    ele_type = web_ele.get_attribute("type")
    # outer_html = web_ele.get_attribute("outerHTML")
    # Warn (but still proceed) when the element is probably not a text input.
    if (ele_tag_name != 'input' and ele_tag_name != 'textarea') or (ele_tag_name == 'input' and ele_type not in ['text', 'search', 'password', 'email', 'tel']):
        warn_obs = f"note: The web element you're trying to type may not be a textbox, and its tag name is <{web_ele.tag_name}>, type is {ele_type}."
    try:
        # Not always work to delete
        web_ele.clear()
        # Another way to delete: select-all, overwrite with a space, backspace.
        if platform.system() == 'Darwin':
            web_ele.send_keys(Keys.COMMAND + "a")
        else:
            web_ele.send_keys(Keys.CONTROL + "a")
        web_ele.send_keys(" ")
        web_ele.send_keys(Keys.BACKSPACE)
    except:
        pass

    actions = ActionChains(driver_task)
    # Focus the element before queuing keystrokes.
    actions.click(web_ele).perform()
    actions.pause(1)

    try:
        # Suppress page-level space-key handlers (e.g. page scroll) while typing.
        driver_task.execute_script("""window.onkeydown = function(e) {if(e.keyCode == 32 && e.target.type != 'text' && e.target.type != 'textarea' && e.target.type != 'search') {e.preventDefault();}};""")
    except:
        pass

    actions.send_keys(type_content)
    actions.pause(2)

    # ENTER submits the input (search boxes, forms).
    actions.send_keys(Keys.ENTER)
    actions.perform()
    # Wait for any navigation/search triggered by ENTER to settle.
    time.sleep(10)
    return warn_obs
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def exec_action_scroll(info, web_eles, driver_task, args, obs_info):
    """Scroll the window or a specific element up/down.

    ``info['number']`` is either the string "WINDOW" or a numerical element
    label; ``info['content']`` is 'up' or 'down'.
    """
    scroll_ele_number = info['number']
    scroll_content = info['content']
    if scroll_ele_number == "WINDOW":
        # Whole-window scroll by two thirds of the viewport height.
        if scroll_content == 'down':
            driver_task.execute_script(f"window.scrollBy(0, {args.window_height*2//3});")
        else:
            driver_task.execute_script(f"window.scrollBy(0, {-args.window_height*2//3});")
    else:
        if not args.text_only:
            # Vision mode: labels index directly into the element list.
            scroll_ele_number = int(scroll_ele_number)
            web_ele = web_eles[scroll_ele_number]
        else:
            # Text-only mode: resolve the element at its bounding-box center.
            element_box = obs_info[scroll_ele_number]['union_bound']
            element_box_center = (element_box[0] + element_box[2] // 2, element_box[1] + element_box[3] // 2)
            web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", element_box_center[0], element_box_center[1])
        actions = ActionChains(driver_task)
        driver_task.execute_script("arguments[0].focus();", web_ele)
        # ALT+arrow scrolls the focused scrollable region.
        if scroll_content == 'down':
            actions.key_down(Keys.ALT).send_keys(Keys.ARROW_DOWN).key_up(Keys.ALT).perform()
        else:
            actions.key_down(Keys.ALT).send_keys(Keys.ARROW_UP).key_up(Keys.ALT).perform()
    # Let the scroll animation finish before the next observation.
    time.sleep(3)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def webvoyager_run(args, task, task_dir):
    """Run the WebVoyager agent loop for a single task.

    Repeatedly: capture an observation (set-of-mark screenshot or accessibility
    tree), query the vision model for a Thought/Action pair, and execute the
    chosen action in the browser, until the model answers or limits are hit.

    Args:
        task: dict with at least 'id', 'web' (start URL) and 'ques' (the task text).
        task_dir: directory where screenshots, trees and logs for this task are saved.
    """
    # OpenAI client
    client = OpenAI(api_key=args.api_key)
    options = driver_config(args)
    setup_logger(task_dir)
    logging.info(f'########## TASK{task["id"]} ##########')

    driver_task = webdriver.Chrome(options=options)
    driver_task.set_window_size(args.window_width, args.window_height)
    driver_task.get(task['web'])
    # Click the body once so keyboard events are delivered to the page.
    try:
        driver_task.find_element(By.TAG_NAME, 'body').click()
    except:
        pass
    # Prevent SPACE from scrolling the page when focus is not in a text field.
    driver_task.execute_script("""window.onkeydown = function(e) {if(e.keyCode == 32 && e.target.type != 'text' && e.target.type != 'textarea') {e.preventDefault();}};""")
    time.sleep(5)

    # Start from an empty download directory so new downloads are detectable.
    for filename in os.listdir(args.download_dir):
        file_path = os.path.join(args.download_dir, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)

    download_files = []  # snapshot of the download dir, for detecting new PDFs
    fail_obs = ""   # feedback when the previous action failed / was malformed
    pdf_obs = ""    # observation text produced after a PDF download
    warn_obs = ""   # warning produced by the Type action (wrong element type)
    # Used with re.split to isolate the Action section of the model reply.
    pattern = r'Thought:|Action:|Observation:'

    messages = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    obs_prompt = "Observation: please analyze the attached screenshot and give the Thought and Action. "
    if args.text_only:
        messages = [{'role': 'system', 'content': SYSTEM_PROMPT_TEXT_ONLY}]
        obs_prompt = "Observation: please analyze the accessibility tree and give the Thought and Action."

    init_msg = f"""Now given a task: {task['ques']} Please interact with https://www.example.com and get the answer. \n"""
    init_msg = init_msg.replace('https://www.example.com', task['web'])
    init_msg = init_msg + obs_prompt

    it = 0
    accumulate_prompt_token = 0
    accumulate_completion_token = 0

    while it < args.max_iter:
        logging.info(f'Iter: {it}')
        it += 1
        if not fail_obs:
            # Normal iteration: build a fresh observation of the page.
            try:
                if not args.text_only:
                    # Draw numbered boxes (set-of-mark) on interactive elements.
                    rects, web_eles, web_eles_text = get_web_element_rect(driver_task, fix_color=args.fix_box_color)
                else:
                    accessibility_tree_path = os.path.join(task_dir, 'accessibility_tree{}'.format(it))
                    ac_tree, obs_info = get_webarena_accessibility_tree(driver_task, accessibility_tree_path)
            except Exception as e:
                if not args.text_only:
                    logging.error('Driver error when adding set-of-mark.')
                else:
                    logging.error('Driver error when obtaining accessibility tree.')
                logging.error(e)
                break

            # Screenshot taken while the set-of-mark boxes are visible.
            img_path = os.path.join(task_dir, 'screenshot{}.png'.format(it))
            driver_task.save_screenshot(img_path)

            if (not args.text_only) and args.save_accessibility_tree:
                accessibility_tree_path = os.path.join(task_dir, 'accessibility_tree{}'.format(it))
                get_webarena_accessibility_tree(driver_task, accessibility_tree_path)

            b64_img = encode_image(img_path)

            if not args.text_only:
                curr_msg = format_msg(it, init_msg, pdf_obs, warn_obs, b64_img, web_eles_text)
            else:
                curr_msg = format_msg_text_only(it, init_msg, pdf_obs, warn_obs, ac_tree)
            messages.append(curr_msg)
        else:
            # Previous action failed: only send the failure feedback, no new screenshot.
            curr_msg = {
                'role': 'user',
                'content': fail_obs
            }
            messages.append(curr_msg)

        # Drop old screenshots / trees from the context to bound prompt size.
        if not args.text_only:
            messages = clip_message_and_obs(messages, args.max_attached_imgs)
        else:
            messages = clip_message_and_obs_text_only(messages, args.max_attached_imgs)

        prompt_tokens, completion_tokens, gpt_call_error, openai_response = call_gpt4v_api(args, client, messages)

        if gpt_call_error:
            break
        else:
            accumulate_prompt_token += prompt_tokens
            accumulate_completion_token += completion_tokens
            logging.info(f'Accumulate Prompt Tokens: {accumulate_prompt_token}; Accumulate Completion Tokens: {accumulate_completion_token}')
            logging.info('API call complete...')
        gpt_4v_res = openai_response.choices[0].message.content
        messages.append({'role': 'assistant', 'content': gpt_4v_res})

        # Remove the set-of-mark overlays before executing the chosen action.
        if (not args.text_only) and 'rects' in locals() and rects:
            logging.info(f"Num of interactive elements: {len(rects)}")
            for rect_ele in rects:
                driver_task.execute_script("arguments[0].remove()", rect_ele)
            rects = []

        # The reply must follow the "Thought: ... Action: ..." format.
        try:
            assert 'Thought:' in gpt_4v_res and 'Action:' in gpt_4v_res
        except AssertionError as e:
            logging.error(e)
            fail_obs = "Format ERROR: Both 'Thought' and 'Action' should be included in your reply."
            continue

        # Index 2 of the split is the text after 'Action:'.
        chosen_action = re.split(pattern, gpt_4v_res)[2].strip()
        action_key, info = extract_information(chosen_action)

        fail_obs = ""
        pdf_obs = ""
        warn_obs = ""
        try:
            window_handle_task = driver_task.current_window_handle
            driver_task.switch_to.window(window_handle_task)

            if action_key == 'click':
                if not args.text_only:
                    click_ele_number = int(info[0])
                    web_ele = web_eles[click_ele_number]
                else:
                    click_ele_number = info[0]
                    # Resolve element via the center of its accessibility bbox.
                    element_box = obs_info[click_ele_number]['union_bound']
                    element_box_center = (element_box[0] + element_box[2] // 2,
                                          element_box[1] + element_box[3] // 2)
                    web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", element_box_center[0], element_box_center[1])

                ele_tag_name = web_ele.tag_name.lower()
                ele_type = web_ele.get_attribute("type")

                exec_action_click(info, web_ele, driver_task)

                # Detect new downloads; if a PDF arrived, query the Assistant
                # API about it and surface the answer as an observation.
                current_files = sorted(os.listdir(args.download_dir))
                if current_files != download_files:
                    # Wait for the download to finish, then re-scan.
                    time.sleep(10)
                    current_files = sorted(os.listdir(args.download_dir))
                    current_download_file = [pdf_file for pdf_file in current_files if pdf_file not in download_files and pdf_file.endswith('.pdf')]
                    if current_download_file:
                        pdf_file = current_download_file[0]
                        pdf_obs = get_pdf_retrieval_ans_from_assistant(client, os.path.join(args.download_dir, pdf_file), task['ques'])
                        shutil.copy(os.path.join(args.download_dir, pdf_file), task_dir)
                        pdf_obs = "You downloaded a PDF file, I ask the Assistant API to answer the task based on the PDF file and get the following response: " + pdf_obs
                    download_files = current_files

                # Form submissions may trigger a slow page load.
                if ele_tag_name == 'button' and ele_type == 'submit':
                    time.sleep(10)

            elif action_key == 'wait':
                time.sleep(5)

            elif action_key == 'type':
                if not args.text_only:
                    type_ele_number = int(info['number'])
                    web_ele = web_eles[type_ele_number]
                else:
                    type_ele_number = info['number']
                    element_box = obs_info[type_ele_number]['union_bound']
                    element_box_center = (element_box[0] + element_box[2] // 2,
                                          element_box[1] + element_box[3] // 2)
                    web_ele = driver_task.execute_script("return document.elementFromPoint(arguments[0], arguments[1]);", element_box_center[0], element_box_center[1])
                warn_obs = exec_action_type(info, web_ele, driver_task)
                # WolframAlpha queries take longer to render results.
                if 'wolfram' in task['web']:
                    time.sleep(5)

            elif action_key == 'scroll':
                if not args.text_only:
                    exec_action_scroll(info, web_eles, driver_task, args, None)
                else:
                    exec_action_scroll(info, None, driver_task, args, obs_info)

            elif action_key == 'goback':
                driver_task.back()
                time.sleep(2)

            elif action_key == 'google':
                driver_task.get('https://www.google.com/')
                time.sleep(2)

            elif action_key == 'answer':
                # Terminal action: the model's final answer ends the loop.
                logging.info(info['content'])
                logging.info('finish!!')
                break
            else:
                raise NotImplementedError
            fail_obs = ""
        except Exception as e:
            logging.error('driver error info:')
            logging.error(e)
            # Intercepted clicks are transient (overlays etc.); don't blame the model.
            if 'element click intercepted' not in str(e):
                fail_obs = "The action you have chosen cannot be exected. Please double-check if you have selected the wrong Numerical Label or Action or Action format. Then provide the revised Thought and Action."
            else:
                fail_obs = ""
            time.sleep(2)

    print_message(messages, task_dir)
    driver_task.quit()
    # Pricing constants here are per-1K-token GPT-4V rates at time of writing.
    logging.info(f'Total cost: {accumulate_prompt_token / 1000 * 0.01 + accumulate_completion_token / 1000 * 0.03}')
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def main():
    """Parse command-line arguments, load the task file and run every task.

    Each task is a JSON object per line of --test_file, with at least
    'id', 'web' (start URL) and 'ques' (the task text). Results for each
    task are written to <output_dir>/<timestamp>/task<id>/.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_file', type=str, default='data/test.json')
    parser.add_argument('--max_iter', type=int, default=5)
    parser.add_argument("--api_key", default="key", type=str, help="YOUR_OPENAI_API_KEY")
    parser.add_argument("--api_model", default="gpt-4-vision-preview", type=str, help="api model name")
    parser.add_argument("--output_dir", type=str, default='results')
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--max_attached_imgs", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--download_dir", type=str, default="downloads")
    parser.add_argument("--text_only", action='store_true')
    # for web browser
    parser.add_argument("--headless", action='store_true', help='The window of selenium')
    parser.add_argument("--save_accessibility_tree", action='store_true')
    parser.add_argument("--force_device_scale", action='store_true')
    parser.add_argument("--window_width", type=int, default=1024)
    parser.add_argument("--window_height", type=int, default=768)
    parser.add_argument("--fix_box_color", action='store_true')

    args = parser.parse_args()

    # Save Result file: one timestamped directory per run.
    current_time = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())
    result_dir = os.path.join(args.output_dir, current_time)
    os.makedirs(result_dir, exist_ok=True)

    # Load tasks (JSON Lines format: one task object per line).
    tasks = []
    with open(args.test_file, 'r', encoding='utf-8') as f:
        for line in f:
            tasks.append(json.loads(line))

    # Iterate tasks directly instead of indexing by range(len(tasks)).
    for task in tasks:
        task_dir = os.path.join(result_dir, 'task{}'.format(task["id"]))
        os.makedirs(task_dir, exist_ok=True)
        webvoyager_run(args, task, task_dir)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# Script entry point: run all tasks, then print a completion marker.
if __name__ == '__main__':
    main()
    print('End of process')
|
utils.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import re
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import time
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from utils_webarena import fetch_browser_info, fetch_page_accessibility_tree,\
|
| 10 |
+
parse_accessibility_tree, clean_accesibility_tree
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resize_image(image_path):
    """Downscale the image at *image_path* so its shorter side is 512 px.

    If the shorter side is already below 512 px the file is left untouched.
    Otherwise the image is resized (preserving aspect ratio, LANCZOS filter)
    and saved back to the same path.

    Returns:
        The PIL Image (original if no resize was needed, else the resized one).
        The original implementation returned None after resizing, which made
        the return value inconsistent between the two branches; it now always
        returns the image.
    """
    image = Image.open(image_path)
    width, height = image.size

    if min(width, height) < 512:
        # Already small enough; nothing to do.
        return image
    elif width < height:
        new_width = 512
        new_height = int(height * (new_width / width))
    else:
        new_height = 512
        new_width = int(width * (new_height / height))

    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    resized_image.save(image_path)
    return resized_image
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# base64 encoding
|
| 32 |
+
# Code from OpenAI Document
|
| 33 |
+
# base64 encoding
# Code from OpenAI Document
def encode_image(image_path):
    """Read the file at *image_path* and return its base64 encoding as ASCII text."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# interact with webpage and add rectangles on elements
|
| 39 |
+
# interact with webpage and add rectangles on elements
def get_web_element_rect(browser, fix_color=True):
    """Draw numbered "set-of-mark" boxes on interactive elements of the page.

    Injects JavaScript that finds visible interactive elements (inputs, links,
    buttons, etc.), draws a dashed outline plus a numeric label over each, and
    returns the overlay elements so they can be removed later.

    Args:
        browser: selenium WebDriver for the page to annotate.
        fix_color: if True all boxes are black; otherwise each gets a random color.

    Returns:
        (rects, web_eles, format_ele_text):
        rects — list of overlay <div> WebElements (pass to .remove() later);
        web_eles — list of the underlying interactive WebElements, index == label;
        format_ele_text — tab-joined text summary '[i]: <tag> "text";' per element.
    """
    if fix_color:
        selected_function = "getFixedColor"
        # color_you_like = '#5210da'
    else:
        selected_function = "getRandomColor"

    # The literal COLOR_FUNCTION below is replaced with the chosen JS function name.
    js_script = """
        let labels = [];

        function markPage() {
            var bodyRect = document.body.getBoundingClientRect();

            var items = Array.prototype.slice.call(
                document.querySelectorAll('*')
            ).map(function(element) {
                var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
                var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

                var rects = [...element.getClientRects()].filter(bb => {
                    var center_x = bb.left + bb.width / 2;
                    var center_y = bb.top + bb.height / 2;
                    var elAtCenter = document.elementFromPoint(center_x, center_y);

                    return elAtCenter === element || element.contains(elAtCenter)
                }).map(bb => {
                    const rect = {
                        left: Math.max(0, bb.left),
                        top: Math.max(0, bb.top),
                        right: Math.min(vw, bb.right),
                        bottom: Math.min(vh, bb.bottom)
                    };
                    return {
                        ...rect,
                        width: rect.right - rect.left,
                        height: rect.bottom - rect.top
                    }
                });

                var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);

                return {
                    element: element,
                    include:
                        (element.tagName === "INPUT" || element.tagName === "TEXTAREA" || element.tagName === "SELECT") ||
                        (element.tagName === "BUTTON" || element.tagName === "A" || (element.onclick != null) || window.getComputedStyle(element).cursor == "pointer") ||
                        (element.tagName === "IFRAME" || element.tagName === "VIDEO" || element.tagName === "LI" || element.tagName === "TD" || element.tagName === "OPTION")
                    ,
                    area,
                    rects,
                    text: element.textContent.trim().replace(/\s{2,}/g, ' ')
                };
            }).filter(item =>
                item.include && (item.area >= 20)
            );

            // Only keep inner clickable items
            // first delete button inner clickable items
            const buttons = Array.from(document.querySelectorAll('button, a, input[type="button"], div[role="button"]'));

            //items = items.filter(x => !buttons.some(y => y.contains(x.element) && !(x.element === y) ));
            items = items.filter(x => !buttons.some(y => items.some(z => z.element === y) && y.contains(x.element) && !(x.element === y) ));
            items = items.filter(x =>
                !(x.element.parentNode &&
                x.element.parentNode.tagName === 'SPAN' &&
                x.element.parentNode.children.length === 1 &&
                x.element.parentNode.getAttribute('role') &&
                items.some(y => y.element === x.element.parentNode)));

            items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))

            // Function to generate random colors
            function getRandomColor(index) {
                var letters = '0123456789ABCDEF';
                var color = '#';
                for (var i = 0; i < 6; i++) {
                    color += letters[Math.floor(Math.random() * 16)];
                }
                return color;
            }

            function getFixedColor(index) {
                var color = '#000000'
                return color
            }
            //function getFixedColor(index){
            //    var colors = ['#FF0000', '#00FF00', '#0000FF', '#000000']; // Red, Green, Blue, Black
            //    return colors[index % 4];
            //}


            // Lets create a floating border on top of these elements that will always be visible
            items.forEach(function(item, index) {
                item.rects.forEach((bbox) => {
                    newElement = document.createElement("div");
                    var borderColor = COLOR_FUNCTION(index);
                    newElement.style.outline = `2px dashed ${borderColor}`;
                    newElement.style.position = "fixed";
                    newElement.style.left = bbox.left + "px";
                    newElement.style.top = bbox.top + "px";
                    newElement.style.width = bbox.width + "px";
                    newElement.style.height = bbox.height + "px";
                    newElement.style.pointerEvents = "none";
                    newElement.style.boxSizing = "border-box";
                    newElement.style.zIndex = 2147483647;
                    // newElement.style.background = `${borderColor}80`;

                    // Add floating label at the corner
                    var label = document.createElement("span");
                    label.textContent = index;
                    label.style.position = "absolute";
                    //label.style.top = "-19px";
                    label.style.top = Math.max(-19, -bbox.top) + "px";
                    //label.style.left = "0px";
                    label.style.left = Math.min(Math.floor(bbox.width / 5), 2) + "px";
                    label.style.background = borderColor;
                    label.style.color = "white";
                    label.style.padding = "2px 4px";
                    label.style.fontSize = "12px";
                    label.style.borderRadius = "2px";
                    newElement.appendChild(label);

                    document.body.appendChild(newElement);
                    labels.push(newElement);
                    // item.element.setAttribute("-ai-label", label.textContent);
                });
            })

            // For the first way
            // return [labels, items.map(item => ({
            //     rect: item.rects[0] // assuming there's at least one rect
            // }))];

            // For the second way
            return [labels, items]
        }
        return markPage();""".replace("COLOR_FUNCTION", selected_function)
    rects, items_raw = browser.execute_script(js_script)

    # Build a short textual description of every marked element, used in the prompt.
    # format_ele_text = [f"[{web_ele_id}]: \"{items_raw[web_ele_id]['text']}\";" for web_ele_id in range(len(items_raw)) if items_raw[web_ele_id]['text'] ]
    format_ele_text = []
    for web_ele_id in range(len(items_raw)):
        label_text = items_raw[web_ele_id]['text']
        ele_tag_name = items_raw[web_ele_id]['element'].tag_name
        ele_type = items_raw[web_ele_id]['element'].get_attribute("type")
        ele_aria_label = items_raw[web_ele_id]['element'].get_attribute("aria-label")
        input_attr_types = ['text', 'search', 'password', 'email', 'tel']

        if not label_text:
            # No visible text: only describe form controls, preferring aria-label.
            if (ele_tag_name.lower() == 'input' and ele_type in input_attr_types) or ele_tag_name.lower() == 'textarea' or (ele_tag_name.lower() == 'button' and ele_type in ['submit', 'button']):
                if ele_aria_label:
                    format_ele_text.append(f"[{web_ele_id}]: <{ele_tag_name}> \"{ele_aria_label}\";")
                else:
                    format_ele_text.append(f"[{web_ele_id}]: <{ele_tag_name}> \"{label_text}\";" )

        elif label_text and len(label_text) < 200:
            # Skip elements whose text is raw <img ... src=...> markup.
            if not ("<img" in label_text and "src=" in label_text):
                if ele_tag_name in ["button", "input", "textarea"]:
                    if ele_aria_label and (ele_aria_label != label_text):
                        format_ele_text.append(f"[{web_ele_id}]: <{ele_tag_name}> \"{label_text}\", \"{ele_aria_label}\";")
                    else:
                        format_ele_text.append(f"[{web_ele_id}]: <{ele_tag_name}> \"{label_text}\";")
                else:
                    if ele_aria_label and (ele_aria_label != label_text):
                        format_ele_text.append(f"[{web_ele_id}]: \"{label_text}\", \"{ele_aria_label}\";")
                    else:
                        format_ele_text.append(f"[{web_ele_id}]: \"{label_text}\";")

    format_ele_text = '\t'.join(format_ele_text)
    return rects, [web_ele['element'] for web_ele in items_raw], format_ele_text
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def extract_information(text):
    """Parse an agent action string into (action_key, payload).

    Returns ('click'|'wait'|'goback'|'google', match.groups() tuple) for
    content-free actions, ('type'|'scroll', {'number', 'content'}) or
    ('answer', {'content'}) for actions carrying content, and (None, None)
    when no known action pattern matches.
    """
    patterns = {
        "click": r"Click \[?(\d+)\]?",
        "type": r"Type \[?(\d+)\]?[; ]+\[?(.[^\]]*)\]?",
        "scroll": r"Scroll \[?(\d+|WINDOW)\]?[; ]+\[?(up|down)\]?",
        "wait": r"^Wait",
        "goback": r"^GoBack",
        "google": r"^Google",
        "answer": r"ANSWER[; ]+\[?(.[^\]]*)\]?"
    }

    no_payload_actions = ("click", "wait", "goback", "google")

    for action_key, regex in patterns.items():
        found = re.search(regex, text)
        if found is None:
            continue
        # First matching pattern wins (dict preserves insertion order).
        if action_key in no_payload_actions:
            return action_key, found.groups()
        if action_key in ("type", "scroll"):
            return action_key, {"number": found.group(1), "content": found.group(2)}
        return action_key, {"content": found.group(1)}
    return None, None
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def clip_message(msg, max_img_num):
    """Limit the conversation to the *max_img_num* most recent image messages.

    Walks the history from newest to oldest. Non-user messages and plain-string
    user messages are kept untouched; user messages with structured (image)
    content are kept for the newest *max_img_num* occurrences and reduced to
    their text part for older ones. Order of the returned list is preserved.
    """
    kept = []
    images_kept = 0
    for curr in reversed(msg):
        if curr['role'] != 'user' or isinstance(curr['content'], str):
            kept.append(curr)
        elif images_kept < max_img_num:
            images_kept += 1
            kept.append(curr)
        else:
            # Too old: drop the image payload, keep only the text portion.
            kept.append({
                'role': curr['role'],
                'content': curr['content'][0]["text"]
            })
    kept.reverse()
    return kept
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def clip_message_and_obs(msg, max_img_num):
    """Keep only the newest *max_img_num* screenshot observations in the history.

    Same traversal as clip_message, but older image-bearing user messages are
    replaced with a short placeholder sentence (noting whether a PDF download
    accompanied the screenshot) instead of the raw text.
    """
    kept = []
    images_kept = 0
    for curr in reversed(msg):
        if curr['role'] != 'user' or isinstance(curr['content'], str):
            kept.append(curr)
        elif images_kept < max_img_num:
            images_kept += 1
            kept.append(curr)
        else:
            # Summarize the stale observation; the pre-"Observation:" text is kept.
            full_text = curr['content'][0]["text"]
            head = full_text.split("Observation:")[0].strip()
            if "You downloaded a PDF file" in full_text:
                summary = head + "Observation: A screenshot, a PDF file and some texts. (Omitted in context.)"
            else:
                summary = head + "Observation: A screenshot and some texts. (Omitted in context.)"
            kept.append({
                'role': curr['role'],
                'content': summary
            })
    kept.reverse()
    return kept
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def clip_message_and_obs_text_only(msg, max_tree_num):
    """Keep only the newest *max_tree_num* accessibility-tree observations.

    Text-only counterpart of clip_message_and_obs: every user message counts
    as a tree observation; older ones are replaced with a short placeholder
    (noting whether a PDF download accompanied the tree).
    """
    kept = []
    trees_kept = 0
    for curr in reversed(msg):
        if curr['role'] != 'user':
            kept.append(curr)
        elif trees_kept < max_tree_num:
            trees_kept += 1
            kept.append(curr)
        else:
            # Summarize the stale observation; the pre-"Observation:" text is kept.
            head = curr['content'].split("Observation:")[0].strip()
            if "You downloaded a PDF file" in curr['content']:
                summary = head + "Observation: An accessibility tree and a PDF file. (Omitted in context.)"
            else:
                summary = head + "Observation: An accessibility tree. (Omitted in context.)"
            kept.append({
                'role': curr['role'],
                'content': summary
            })
    kept.reverse()
    return kept
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def print_message(json_object, save_dir=None):
    """Log the conversation and optionally dump it to interact_messages.json.

    Base64 image payloads are replaced with a literal placeholder so logs and
    the saved JSON stay small.

    Args:
        json_object: list of chat messages ({'role', 'content'} dicts).
        save_dir: if given, the redacted messages are written to
            <save_dir>/interact_messages.json.

    Bug fix: the original implementation mutated the caller's message objects
    in place while redacting image payloads (the 'content' list was shared);
    redacted copies are built instead, leaving the input untouched.
    """
    remove_b64code_obj = []
    for obj in json_object:
        if obj['role'] != 'user' or type(obj['content']) == str:
            # Plain messages carry no image payload; log and keep as-is.
            logging.info(obj)
            remove_b64code_obj.append(obj)
        else:
            # Structured user message: redact image items into fresh dicts.
            redacted_content = []
            for item in obj['content']:
                if item['type'] == 'image_url':
                    # Placeholder string is intentionally NOT an f-string.
                    redacted_content.append({**item, 'image_url': {"url": "data:image/png;base64,{b64_img}"}})
                else:
                    redacted_content.append(item)
            print_obj = {
                'role': obj['role'],
                'content': redacted_content
            }
            logging.info(print_obj)
            remove_b64code_obj.append(print_obj)
    if save_dir:
        with open(os.path.join(save_dir, 'interact_messages.json'), 'w', encoding='utf-8') as fw:
            json.dump(remove_b64code_obj, fw, indent=2)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def get_webarena_accessibility_tree(browser, save_file=None):
    """Fetch, parse and clean the WebArena accessibility tree for the viewport.

    Args:
        browser: selenium WebDriver for the current page.
        save_file: optional path stem; when given, writes <stem>.json (node
            info) and <stem>.txt (cleaned tree text).

    Returns:
        (content, obs_nodes_info): the cleaned tree text and the node metadata.
    """
    browser_info = fetch_browser_info(browser)
    tree = fetch_page_accessibility_tree(browser_info, browser, current_viewport_only=True)
    content, obs_nodes_info = parse_accessibility_tree(tree)
    content = clean_accesibility_tree(content)

    if save_file:
        with open(save_file + '.json', 'w', encoding='utf-8') as fw:
            json.dump(obs_nodes_info, fw, indent=2)
        with open(save_file + '.txt', 'w', encoding='utf-8') as fw:
            fw.write(content)

    return content, obs_nodes_info
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def compare_images(img1_path, img2_path):
    """Return the summed absolute per-pixel difference between two images.

    The images must have identical dimensions and mode (they come from
    screenshots of the same window).

    Bug fix: PIL yields uint8 arrays, and uint8 subtraction wraps modulo 256,
    so np.abs(img1 - img2) was wrong whenever img2 > img1 at a pixel. The
    arrays are now cast to a signed integer dtype before subtracting.
    """
    img1 = Image.open(img1_path)
    img2 = Image.open(img2_path)

    img1_array = np.asarray(img1, dtype=np.int64)
    img2_array = np.asarray(img2, dtype=np.int64)

    difference = np.abs(img1_array - img2_array)

    total_difference = np.sum(difference)

    return total_difference
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def get_pdf_retrieval_ans_from_assistant(client, pdf_path, task):
    """Ask the OpenAI Assistants (retrieval) API to answer *task* from a PDF.

    Uploads the PDF, creates a temporary assistant + thread, polls the run to
    completion, and returns the assistant's answer text. The uploaded file and
    the assistant are deleted afterwards (even on failure).

    Raises:
        RuntimeError: if the run ends in a terminal non-completed state.
            (Previously the polling loop spun forever on 'failed'/'cancelled'/
            'expired' runs.)
    """
    logging.info("You download a PDF file that will be retrieved using the Assistant API.")
    file = client.files.create(
        file=open(pdf_path, "rb"),
        purpose='assistants'
    )
    logging.info("Create assistant...")
    assistant = client.beta.assistants.create(
        instructions="You are a helpful assistant that can analyze the content of a PDF file and give an answer that matches the given task, or retrieve relevant content that matches the task.",
        model="gpt-4-1106-preview",
        tools=[{"type": "retrieval"}],
        file_ids=[file.id]
    )
    try:
        thread = client.beta.threads.create()
        message = client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=task,
            file_ids=[file.id]
        )
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant.id
        )
        while True:
            # Retrieve the run status
            run_status = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
            if run_status.status == 'completed':
                break
            if run_status.status in ('failed', 'cancelled', 'expired'):
                # Terminal failure: bail out instead of polling forever.
                raise RuntimeError(f"Assistant run ended with status: {run_status.status}")
            time.sleep(2)
        messages = client.beta.threads.messages.list(thread_id=thread.id)
        # The newest message (index 0) is the assistant's answer.
        messages_text = messages.data[0].content[0].text.value
    finally:
        # Always clean up the uploaded file and the temporary assistant.
        file_deletion_status = client.beta.assistants.files.delete(
            assistant_id=assistant.id,
            file_id=file.id
        )
        logging.info(file_deletion_status)
        assistant_deletion_status = client.beta.assistants.delete(assistant.id)
        logging.info(assistant_deletion_status)
    return messages_text
|