harvesthealth committed on
Commit
120f02c
·
verified ·
1 Parent(s): 1bcb2af

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. utils_webarena.py +389 -0
utils_webarena.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, TypedDict
2
+ import re
3
+
4
+
5
class AccessibilityTreeNode(TypedDict):
    """One node of the accessibility tree handled by this module.

    The nodes are obtained from the ``Accessibility.getFullAXTree`` CDP
    command; ``union_bound`` is added afterwards by
    ``fetch_page_accessibility_tree``.
    """

    # identifier used to index nodes and to link parents and children
    nodeId: str
    ignored: bool
    # the role string is read from role["value"]
    role: dict[str, Any]
    chromeRole: dict[str, Any]
    # the display name is read from name["value"]
    name: dict[str, Any]
    # each entry is read as {"name": ..., "value": {"value": ...}}
    properties: list[dict[str, Any]]
    # nodeIds of the children, in display order
    childIds: list[str]
    # nodeId of the parent; overwritten with "[REMOVED]" for filtered nodes
    parentId: str
    # NOTE(review): annotated as str, but callers wrap the value in str()
    # before use, so the raw value is presumably an integer — confirm.
    backendDOMNodeId: str
    frameId: str
    bound: list[float] | None
    # [x, y, width, height] of the bounding client rect, or None when the
    # node has no backing DOM node or the rect could not be fetched
    union_bound: list[float] | None
    offsetrect_bound: list[float] | None
19
+
20
+
21
class BrowserConfig(TypedDict):
    """Window geometry collected by ``fetch_browser_info``."""

    # window.pageYOffset
    win_top_bound: float
    # window.pageXOffset
    win_left_bound: float
    # window.screen.width
    win_width: float
    # window.screen.height
    win_height: float
    # win_left_bound + win_width
    win_right_bound: float
    # win_top_bound + win_height
    win_lower_bound: float
    # window.devicePixelRatio
    device_pixel_ratio: float
29
+
30
+
31
class BrowserInfo(TypedDict):
    """Return value of ``fetch_browser_info``."""

    # result of the "DOMSnapshot.captureSnapshot" CDP command
    DOMTree: dict[str, Any]
    # window geometry, see BrowserConfig
    config: BrowserConfig
34
+
35
# Accessibility property names that are skipped when a node is rendered as
# text by parse_accessibility_tree.
IGNORED_ACTREE_PROPERTIES = (
    "focusable",
    "editable",
    "readonly",
    "level",
    "settable",
    "multiline",
    "invalid",
)

# Flat list of accessibility nodes; parents and children are linked by nodeId.
AccessibilityTree = list[AccessibilityTreeNode]

# A node whose in-viewport ratio is below this value is removed when only the
# current viewport is requested.
IN_VIEWPORT_RATIO_THRESHOLD = 0.6
48
+
49
+
50
+
51
def fetch_browser_info(
    browser,
) -> BrowserInfo:
    """Capture a DOM snapshot of the current page plus the window geometry.

    Args:
        browser: driver object exposing ``execute_cdp_cmd``,
            ``get_window_size`` and ``execute_script`` (presumably a
            Selenium Chromium WebDriver — confirm against the caller).

    Returns:
        A dict with the ``DOMSnapshot.captureSnapshot`` result under
        ``"DOMTree"`` (layout bounds of the first document rescaled in
        place) and the window geometry under ``"config"``.

    Raises:
        AssertionError: if ``window.devicePixelRatio`` is not 1.0.
    """
    # extract domtree
    tree = browser.execute_cdp_cmd(
        "DOMSnapshot.captureSnapshot",
        {
            "computedStyles": [],
            "includeDOMRects": True,
            "includePaintOrder": True,
        },
    )

    # Calibrate the bounds: in some cases they are scaled, so normalise them
    # so that the first bound is as wide as the browser window.
    # NOTE: only the first document of the snapshot is calibrated.
    layout = tree["documents"][0]["layout"]
    bounds = layout["bounds"]
    if bounds:
        window_width = browser.get_window_size()["width"]
        scale = bounds[0][2] / window_width if window_width else 0
        # A zero scale (zero-width first bound) cannot be divided by; leave
        # the bounds untouched instead of crashing.
        if scale:
            layout["bounds"] = [
                [value / scale for value in bound] for bound in bounds
            ]

    # extract browser info
    win_top_bound = browser.execute_script("return window.pageYOffset;")
    win_left_bound = browser.execute_script("return window.pageXOffset;")
    win_width = browser.execute_script("return window.screen.width;")
    win_height = browser.execute_script("return window.screen.height;")
    win_right_bound = win_left_bound + win_width
    win_lower_bound = win_top_bound + win_height
    device_pixel_ratio = browser.execute_script(
        "return window.devicePixelRatio;"
    )
    # Raised explicitly (same type and message as the former ``assert``) so
    # the check still runs when Python is started with -O.
    if device_pixel_ratio != 1.0:
        raise AssertionError("devicePixelRatio is not 1.0")

    config: BrowserConfig = {
        "win_top_bound": win_top_bound,
        "win_left_bound": win_left_bound,
        "win_width": win_width,
        "win_height": win_height,
        "win_right_bound": win_right_bound,
        "win_lower_bound": win_lower_bound,
        "device_pixel_ratio": device_pixel_ratio,
    }

    info: BrowserInfo = {"DOMTree": tree, "config": config}

    return info
101
+
102
+
103
+
104
+
105
def get_element_in_viewport_ratio(
    elem_left_bound: float,
    elem_top_bound: float,
    width: float,
    height: float,
    config: BrowserConfig,
) -> float:
    """Return the fraction of an element's area that lies inside the viewport.

    The viewport is the rectangle from (0, 0) to
    (``config["win_width"]``, ``config["win_height"]``); the element
    coordinates are expected in the same coordinate system.

    Args:
        elem_left_bound: x coordinate of the element's left edge.
        elem_top_bound: y coordinate of the element's top edge.
        width: element width.
        height: element height.
        config: window geometry; only ``win_width`` and ``win_height``
            are read.

    Returns:
        A value in [0.0, 1.0]; 0.0 for an element with no area.
    """
    # An element without area cannot be visible (and would divide by zero).
    if width <= 0 or height <= 0:
        return 0.0

    elem_right_bound = elem_left_bound + width
    elem_lower_bound = elem_top_bound + height

    win_left_bound = 0
    win_right_bound = config["win_width"]
    win_top_bound = 0
    win_lower_bound = config["win_height"]

    # Compute the overlap in x and y axes
    overlap_width = max(
        0,
        min(elem_right_bound, win_right_bound)
        - max(elem_left_bound, win_left_bound),
    )
    overlap_height = max(
        0,
        min(elem_lower_bound, win_lower_bound)
        - max(elem_top_bound, win_top_bound),
    )

    # Overlap area divided by the element area. The parentheses matter:
    # ``a / width * height`` would multiply by the height instead.
    return overlap_width * overlap_height / (width * height)
135
+
136
+
137
+
138
+
139
def get_bounding_client_rect(
    browser, backend_node_id: str
) -> dict[str, Any]:
    """Fetch the bounding client rect of a DOM node through CDP.

    The node is resolved with ``DOM.resolveNode`` and measured with
    ``Runtime.callFunctionOn``. Text nodes (``nodeType == 3``) are measured
    through a temporary ``Range`` because they have no
    ``getBoundingClientRect`` of their own.

    Args:
        browser: driver object exposing ``execute_cdp_cmd``.
        backend_node_id: backend DOM node id; converted with ``int()``.

    Returns:
        The raw ``Runtime.callFunctionOn`` response, or
        ``{"result": {"subtype": "error"}}`` when anything goes wrong
        (best effort: callers treat that as "no bounds available").
    """
    try:
        remote_object = browser.execute_cdp_cmd(
            "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
        )
        remote_object_id = remote_object["object"]["objectId"]
        response = browser.execute_cdp_cmd(
            "Runtime.callFunctionOn",
            {
                "objectId": remote_object_id,
                "functionDeclaration": """
                    function() {
                        if (this.nodeType == 3) {
                            var range = document.createRange();
                            range.selectNode(this);
                            var rect = range.getBoundingClientRect().toJSON();
                            range.detach();
                            return rect;
                        } else {
                            return this.getBoundingClientRect().toJSON();
                        }
                    }
                """,
                "returnByValue": True,
            },
        )
        return response
    # Deliberately broad, but ``Exception`` rather than a bare ``except`` so
    # KeyboardInterrupt and SystemExit still propagate.
    except Exception:
        return {"result": {"subtype": "error"}}
170
+
171
+
172
def fetch_page_accessibility_tree(
    info: BrowserInfo,
    browser,
    # client: CDPSession,
    current_viewport_only: bool,
) -> AccessibilityTree:
    """Fetch the page's accessibility tree and attach a bound to every node.

    Args:
        info: browser info; only ``info["config"]`` is read, and only when
            ``current_viewport_only`` is true.
        browser: driver object exposing ``execute_cdp_cmd``.
        current_viewport_only: when true, nodes without a bound, with a
            zero-sized bound, or with an in-viewport ratio below
            ``IN_VIEWPORT_RATIO_THRESHOLD`` are spliced out of the tree.

    Returns:
        The de-duplicated node list; every node carries ``union_bound``
        (``[x, y, width, height]`` or ``None``).
    """
    raw_nodes: AccessibilityTree = browser.execute_cdp_cmd(
        "Accessibility.getFullAXTree", {}
    )["nodes"]

    # a few nodes are repeated in the accessibility tree: keep the first
    # occurrence of every nodeId
    known_ids = set()
    accessibility_tree = []
    for ax_node in raw_nodes:
        if ax_node["nodeId"] in known_ids:
            continue
        known_ids.add(ax_node["nodeId"])
        accessibility_tree.append(ax_node)

    # nodeId -> position of the node in ``accessibility_tree``
    nodeid_to_cursor = {
        ax_node["nodeId"]: position
        for position, ax_node in enumerate(accessibility_tree)
    }

    for ax_node in accessibility_tree:
        # usually because the node is not visible etc
        if "backendDOMNodeId" not in ax_node:
            ax_node["union_bound"] = None
            continue

        if ax_node["role"]["value"] == "RootWebArea":
            # always inside the viewport
            ax_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            continue

        response = get_bounding_client_rect(
            browser, str(ax_node["backendDOMNodeId"])
        )
        if response.get("result", {}).get("subtype", "") == "error":
            ax_node["union_bound"] = None
            continue

        rect = response["result"]["value"]
        ax_node["union_bound"] = [
            rect["x"],
            rect["y"],
            rect["width"],
            rect["height"],
        ]

    if not current_viewport_only:
        return accessibility_tree

    # filter nodes that are not in the current viewport

    def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
        """Splice ``node`` out of the tree, re-parenting its children."""
        removed_id = node["nodeId"]
        removed_cursor = nodeid_to_cursor[removed_id]
        parent_id = node["parentId"]
        orphan_ids = node["childIds"]
        parent = accessibility_tree[nodeid_to_cursor[parent_id]]
        assert parent.get("parentId", "Root") is not None

        # the removed node's children take its place among its siblings
        siblings = parent["childIds"]
        slot = siblings.index(removed_id)
        siblings[slot : slot + 1] = orphan_ids

        # the children now hang off the removed node's parent
        for orphan_id in orphan_ids:
            accessibility_tree[nodeid_to_cursor[orphan_id]][
                "parentId"
            ] = parent_id

        # mark as removed
        accessibility_tree[removed_cursor]["parentId"] = "[REMOVED]"

    config = info["config"]
    for ax_node in accessibility_tree:
        bound = ax_node["union_bound"]
        if not bound:
            remove_node_in_graph(ax_node)
            continue

        [x, y, width, height] = bound

        # invisible node
        if width == 0 or height == 0:
            remove_node_in_graph(ax_node)
            continue

        visible_ratio = get_element_in_viewport_ratio(
            elem_left_bound=float(x),
            elem_top_bound=float(y),
            width=float(width),
            height=float(height),
            config=config,
        )
        if visible_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
            remove_node_in_graph(ax_node)

    return [
        ax_node
        for ax_node in accessibility_tree
        if ax_node.get("parentId", "Root") != "[REMOVED]"
    ]
281
+
282
+
283
def parse_accessibility_tree(
    accessibility_tree: AccessibilityTree,
) -> tuple[str, dict[str, Any]]:
    """Render the accessibility tree as tab-indented text.

    The tree is walked depth first from the first node of the list. Each
    kept node becomes one line ``[nodeId] role 'name' prop: value ...``;
    nodes judged uninformative are skipped and their children take over
    their indentation level.

    Args:
        accessibility_tree: flat node list; children are looked up through
            ``childIds``. Child ids that are not in the list are ignored.

    Returns:
        ``(tree_str, obs_nodes_info)`` where ``obs_nodes_info`` maps the
        node id of every kept node to its ``backend_id``, ``union_bound``
        and rendered ``text``. An empty tree yields ``("", {})``.
    """
    obs_nodes_info: dict[str, Any] = {}
    if not accessibility_tree:
        return "", obs_nodes_info

    node_id_to_idx = {
        node["nodeId"]: idx for idx, node in enumerate(accessibility_tree)
    }

    # roles that carry no information without a name and without properties
    empty_roles = (
        "generic",
        "img",
        "list",
        "strong",
        "paragraph",
        "banner",
        "navigation",
        "Section",
        "LabelText",
        "Legend",
        "listitem",
    )

    def dfs(idx: int, obs_node_id: str, depth: int) -> str:
        tree_str = ""
        node = accessibility_tree[idx]
        indent = "\t" * depth
        valid_node = True
        try:
            role = node["role"]["value"]
            name = node["name"]["value"]
            node_str = f"[{obs_node_id}] {role} {repr(name)}"
            properties = []
            for prop in node.get("properties", []):
                try:
                    if prop["name"] in IGNORED_ACTREE_PROPERTIES:
                        continue
                    properties.append(
                        f'{prop["name"]}: {prop["value"]["value"]}'
                    )
                except KeyError:
                    # malformed property entry: skip it
                    pass

            if properties:
                node_str += " " + " ".join(properties)

            # A nameless node is dropped when it is a list item, or when it
            # has no properties and one of the uninformative roles.
            if not name.strip():
                if role == "listitem" or (
                    not properties and role in empty_roles
                ):
                    valid_node = False

            if valid_node:
                tree_str += f"{indent}{node_str}"
                # NOTE(review): a node without "backendDOMNodeId" raises
                # here, after its line was already emitted; it then counts
                # as invalid for indentation and is missing from
                # obs_nodes_info. Kept as is — confirm this is intended.
                obs_nodes_info[obs_node_id] = {
                    "backend_id": node["backendDOMNodeId"],
                    "union_bound": node["union_bound"],
                    "text": node_str,
                }

        # Best effort: a malformed node is skipped rather than aborting the
        # whole tree, but KeyboardInterrupt/SystemExit must propagate.
        except Exception:
            valid_node = False

        # children of a skipped node keep its depth to save some tokens
        child_depth = depth + 1 if valid_node else depth
        for child_node_id in node["childIds"]:
            if child_node_id not in node_id_to_idx:
                continue
            child_str = dfs(
                node_id_to_idx[child_node_id], child_node_id, child_depth
            )
            if child_str.strip():
                if tree_str.strip():
                    tree_str += "\n"
                tree_str += child_str

        return tree_str

    tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
    return tree_str, obs_nodes_info
368
+
369
+
370
# Matches a rendered StaticText node. The text is produced with repr(), so it
# is wrapped in single quotes, or in double quotes when it contains an
# apostrophe; group 1 / group 2 capture the text for the respective style.
_STATIC_TEXT_PATTERN = re.compile(
    r"""\[\d+\] StaticText (?:'((?:[^'\\]|\\.)*)'|"((?:[^"\\]|\\.)*)")"""
)


def clean_accesibility_tree(tree_str: str) -> str:
    """Remove redundant StaticText lines from a rendered accessibility tree.

    A StaticText line is dropped when its text is empty or already appears
    in one of the three most recently kept lines (typically the parent node
    that carries the same name). Every other line is kept unchanged,
    including lines that merely mention "statictext" without being a
    StaticText node.

    Args:
        tree_str: newline-separated tree text.

    Returns:
        The tree text without the redundant lines.
    """
    clean_lines: list[str] = []
    for line in tree_str.split("\n"):
        match = _STATIC_TEXT_PATTERN.search(line)
        if match is None:
            clean_lines.append(line)
            continue

        static_text = (
            match.group(1) if match.group(1) is not None else match.group(2)
        )
        if static_text and all(
            static_text not in prev_line for prev_line in clean_lines[-3:]
        ):
            clean_lines.append(line)

    return "\n".join(clean_lines)