# Author: AUXteam
# Commit 308928c (verified): Fix syntax error in element serialization wrapper
import json
import os
import re
import shutil

from gradio_client import Client
class MCPWebDriver:
    """Selenium-style driver shim that forwards browser actions to a Gradio Space.

    Every operation is a ``Client.predict`` call against a named endpoint of
    the Space (``/get_page_info``, ``/screenshot``, ``/execute_js``,
    ``/close_persistent_driver``).  The positional boolean flags passed to
    those endpoints are defined by the Space itself.
    NOTE(review): the trailing ``True`` presumably selects a persistent
    browser session -- confirm against the Space's API.
    """

    class ChromeOptions:
        """No-op stand-in for Selenium's ``ChromeOptions``; all settings are discarded."""

        def __init__(self):
            pass

        def add_argument(self, arg):
            """Accept and ignore a browser command-line argument."""
            pass

        def add_experimental_option(self, name, value):
            """Accept and ignore an experimental option."""
            pass

    # JavaScript placed before the caller's script.  ``wrapResult`` tags every
    # DOM element found in the script's return value with a ``data-mcp-id``
    # attribute and replaces it with a plain descriptor
    # ``{__mcp_type: 'element', selector, tag_name}`` that survives JSON
    # serialization; ``_reconstruct_elements`` turns descriptors back into
    # ``MCPWebElement`` objects on the Python side.
    _WRAPPER_PREFIX = """
(function() {
    const wrapResult = (obj) => {
        if (!obj) return obj;
        if (typeof Element !== 'undefined' && obj instanceof Element) {
            let id = obj.getAttribute('data-mcp-id');
            if (!id) {
                id = 'mcp-' + Math.random().toString(36).substr(2, 9);
                obj.setAttribute('data-mcp-id', id);
            }
            return { __mcp_type: 'element', selector: '[data-mcp-id="' + id + '"]', tag_name: obj.tagName.toLowerCase() };
        }
        if (Array.isArray(obj)) return obj.map(wrapResult);
        if (typeof obj === 'object' && obj !== null) {
            const newObj = {};
            for (let key in obj) {
                if (Object.prototype.hasOwnProperty.call(obj, key)) {
                    try { newObj[key] = wrapResult(obj[key]); }
                    catch(e) { }
                }
            }
            return newObj;
        }
        return obj;
    };
    const result = (function() {
"""

    # JavaScript placed after the caller's script.  ``%s`` receives the
    # comma-separated call arguments, so the script can read ``arguments[n]``.
    _WRAPPER_SUFFIX = """
    })(%s);
    return wrapResult(result);
})()
"""

    # Prefix the ``/execute_js`` endpoint may put in front of the JSON payload.
    _RESULT_PREFIX = "Script executed. Result: "

    def __init__(self, space_id="diamond-in/Browser-Use-mcp"):
        """Connect to the Gradio Space ``space_id`` and set default window state."""
        self.client = Client(space_id)
        self.current_url = ""
        self.width = 1920
        self.height = 1080

    def set_window_size(self, width, height):
        """Record the requested window size locally; nothing is sent to the Space."""
        self.width, self.height = width, height

    def get(self, url):
        """Remember ``url`` as the current page and request its page info."""
        self.current_url = url
        self.client.predict(url, True, api_name="/get_page_info")

    def save_screenshot(self, path):
        """Capture the current page and move the resulting image file to ``path``.

        The ``/screenshot`` endpoint is expected to answer with either a dict
        carrying a ``'path'`` key or a plain path string to an existing file.

        Returns:
            ``True`` when a file was moved to ``path``, ``False`` when the
            endpoint's answer did not reference a usable file.

        Raises:
            OSError: if the file cannot be moved.
        """
        res = self.client.predict(self.current_url, False, True, api_name="/screenshot")
        if isinstance(res, dict) and 'path' in res:
            source = res['path']
        elif isinstance(res, str) and os.path.exists(res):
            source = res
        else:
            return False
        # shutil.move instead of os.rename: the downloaded file may live on a
        # different filesystem (e.g. a temp directory), where rename fails.
        shutil.move(source, path)
        return True

    @classmethod
    def _to_js(cls, value):
        """Render a Python script argument as a JavaScript expression."""
        if value is None or isinstance(value, (str, int, float, bool)):
            return json.dumps(value)
        if isinstance(value, (list, tuple)):
            return "[" + ", ".join(cls._to_js(item) for item in value) + "]"
        if isinstance(value, dict):
            pairs = (json.dumps(str(key)) + ": " + cls._to_js(item) for key, item in value.items())
            return "{" + ", ".join(pairs) + "}"
        if isinstance(value, MCPWebElement):
            # Elements are re-resolved in the page through their tagged selector.
            return "document.querySelector(" + json.dumps(value.selector) + ")"
        # Anything else stays undefined in the page, exactly as when arguments
        # were not forwarded at all.
        return "undefined"

    @classmethod
    def _decode_result(cls, raw):
        """Decode the ``/execute_js`` answer into Python data.

        Strings are parsed as JSON, optionally after stripping the
        ``"Script executed. Result: "`` prefix; text that is not JSON is
        returned as-is (without the prefix).  Non-string answers are treated
        as already decoded and returned unchanged.
        """
        if not isinstance(raw, str):
            return raw
        try:
            return json.loads(raw)
        except ValueError:
            pass
        if raw.startswith(cls._RESULT_PREFIX):
            raw = raw[len(cls._RESULT_PREFIX):]
            try:
                return json.loads(raw)
            except ValueError:
                pass
        return raw

    def execute_script(self, script, *args):
        """Run ``script`` in the current page and return its (decoded) result.

        The script becomes the body of a JavaScript function, so it may use
        ``return`` and read its inputs from ``arguments``.  ``args`` may be
        JSON-like values (``None``, bool, numbers, strings, lists, tuples,
        dicts) or ``MCPWebElement`` instances; other values arrive as
        ``undefined``.  DOM elements in the result come back as
        ``MCPWebElement`` objects.

        Returns:
            The decoded result, or ``[]`` if the remote call or decoding
            raised (the error is printed, not propagated).
        """
        js_args = ", ".join(self._to_js(arg) for arg in args)
        final_script = self._WRAPPER_PREFIX + script + (self._WRAPPER_SUFFIX % js_args)
        try:
            res = self.client.predict(self.current_url, final_script, True, api_name="/execute_js")
            return self._reconstruct_elements(self._decode_result(res))
        except Exception as e:
            # Deliberate best-effort: callers that unpack the result get an
            # empty list instead of an exception.
            print(f"Error executing script: {e}")
            return []

    def _reconstruct_elements(self, data):
        """Recursively replace element descriptors in ``data`` with ``MCPWebElement`` objects."""
        if isinstance(data, dict):
            if data.get('__mcp_type') == 'element':
                return MCPWebElement(self, data['selector'], data.get('tag_name', 'div'))
            return {k: self._reconstruct_elements(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self._reconstruct_elements(v) for v in data]
        return data

    def quit(self):
        """Ask the Space to close its persistent driver."""
        self.client.predict(api_name="/close_persistent_driver")

    def find_elements(self, by, value):
        """Stub: element lookup is not implemented and always yields an empty list."""
        return []

    @property
    def title(self):
        """Title of the current page, or ``""`` when the page info cannot be decoded."""
        res = self.client.predict(self.current_url, True, api_name="/get_page_info")
        try:
            return json.loads(res).get("title", "")
        except (TypeError, ValueError, AttributeError):
            # Not a string, not JSON, or JSON that is not an object.
            return ""
class MCPWebElement:
    """Handle to a page element, addressed by a CSS selector through its driver."""

    def __init__(self, driver, selector, tag_name="div"):
        """Store the owning driver, the element's selector and its tag name."""
        self.driver = driver
        self.selector = selector
        self.tag_name = tag_name

    def click(self):
        """Send this element's selector to the driver's ``/click`` endpoint."""
        remote = self.driver.client
        remote.predict(self.driver.current_url, self.selector, True, api_name="/click")

    def send_keys(self, *keys):
        """Concatenate the string arguments and send them to the ``/fill`` endpoint.

        Non-string arguments are skipped.
        """
        pieces = []
        for key in keys:
            if isinstance(key, str):
                pieces.append(str(key))
        text = "".join(pieces)
        remote = self.driver.client
        remote.predict(self.driver.current_url, self.selector, text, True, api_name="/fill")

    def get_attribute(self, name):
        """Stub: attribute lookup is not implemented and always yields ``""``."""
        return ""

    def clear(self):
        """Stub: does nothing."""
        return None
class ActionChains:
    """Minimal fluent action builder: actions run immediately instead of being queued."""

    def __init__(self, driver):
        """Keep the driver whose client receives the actions."""
        self.driver = driver

    def click(self, element=None):
        """Click ``element`` right away when one is given; return ``self`` for chaining."""
        if not element:
            return self
        element.click()
        return self

    def send_keys(self, *keys):
        """Send the concatenated string arguments to ``body`` via the ``/fill`` endpoint.

        Non-string arguments are skipped.  Returns ``self`` for chaining.
        """
        typed = [str(key) for key in keys if isinstance(key, str)]
        text = "".join(typed)
        remote = self.driver.client
        remote.predict(self.driver.current_url, "body", text, True, api_name="/fill")
        return self

    def perform(self):
        """No-op: every action has already been executed when it was added."""
        return None

    def pause(self, seconds):
        """No-op kept for chaining; no waiting takes place."""
        return self

    def key_down(self, key):
        """No-op kept for chaining."""
        return self

    def key_up(self, key):
        """No-op kept for chaining."""
        return self