johnisafridge committed
Commit fb9ecd6 · verified · 1 Parent(s): a404749

Update app.py to include port and 0.0.0.0 in app launch
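
In short, the app now binds to all interfaces and reads the port from the environment instead of relying on Gradio's defaults. As a minimal standalone sketch of the same pattern (assuming only that the hosting platform may inject a PORT environment variable; 7860 is Gradio's default port, and the commit additionally calls demo.queue(api_open=False) before launching, which this sketch omits):

import os
import gradio as gr

PORT = int(os.getenv("PORT", "7860"))  # env vars are strings, so convert explicitly

with gr.Blocks() as demo:
    gr.Markdown("hello")

# server_name="0.0.0.0" makes the server reachable from outside the container,
# not just via localhost inside it; server_port pins the port the platform expects.
demo.launch(server_name="0.0.0.0", server_port=PORT)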

Files changed (1)
  1. app.py +58 -40
app.py CHANGED
@@ -1,6 +1,7 @@
+import os
 import gradio as gr
 import spaces
-from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 import base64
@@ -9,8 +10,15 @@ from io import BytesIO
 import re
 
 
+# ---- HF Spaces: ensure we read the platform port ----
+PORT = int(os.getenv("PORT", "7860"))
+
 models = {
-    "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained("OS-Copilot/OS-Atlas-Base-7B", torch_dtype="auto", device_map="auto"),
+    "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained(
+        "OS-Copilot/OS-Atlas-Base-7B",
+        torch_dtype="auto",
+        device_map="auto",
+    ),
 }
 
 processors = {
@@ -18,14 +26,13 @@ processors = {
 }
 
 
-def image_to_base64(image):
+def image_to_base64(image: Image.Image) -> str:
     buffered = BytesIO()
     image.save(buffered, format="PNG")
-    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-    return img_str
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+def draw_bounding_boxes(image: Image.Image, bounding_boxes, outline_color="red", line_width=2):
     draw = ImageDraw.Draw(image)
     for box in bounding_boxes:
         xmin, ymin, xmax, ymax = box
@@ -39,13 +46,7 @@ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scal
     rescaled_boxes = []
     for box in bounding_boxes:
         xmin, ymin, xmax, ymax = box
-        rescaled_box = [
-            xmin * x_scale,
-            ymin * y_scale,
-            xmax * x_scale,
-            ymax * y_scale
-        ]
-        rescaled_boxes.append(rescaled_box)
+        rescaled_boxes.append([xmin * x_scale, ymin * y_scale, xmax * x_scale, ymax * y_scale])
     return rescaled_boxes
 
 
@@ -53,7 +54,8 @@ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scal
 def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
     model = models[model_id].eval()
     processor = processors[model_id]
-    prompt = f"In this UI screenshot, what is the position of the element corresponding to the command \"{text_input}\" (with bbox)?"
+
+    prompt = f'In this UI screenshot, what is the position of the element corresponding to the command "{text_input}" (with bbox)?'
     messages = [
         {
             "role": "user",
@@ -64,9 +66,7 @@ def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
         }
     ]
 
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
         text=[text],
@@ -74,47 +74,60 @@ def run_example(image, text_input, model_id="OS-Copilot/OS-Atlas-Base-7B"):
         videos=video_inputs,
         padding=True,
         return_tensors="pt",
-    )
-    inputs = inputs.to("cuda")
+    ).to("cuda")
 
     generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
+    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
     output_text = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
     )
-    print(output_text)
     text = output_text[0]
 
+    # ---- simple, defensive parsing so the Space doesn't 500 if pattern not found ----
    object_ref_pattern = r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
     box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
 
-    object_ref = re.search(object_ref_pattern, text).group(1)
-    box_content = re.search(box_pattern, text).group(1)
+    object_match = re.search(object_ref_pattern, text or "")
+    box_match = re.search(box_pattern, text or "")
+
+    object_ref = object_match.group(1) if object_match else ""
+    box_content = box_match.group(1) if box_match else ""
 
-    boxes = [tuple(map(int, pair.strip("()").split(','))) for pair in box_content.split("),(")]
-    boxes = [[boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]]
+    boxes = []
+    if box_content:
+        try:
+            parsed = [tuple(map(int, pair.strip("()").split(","))) for pair in box_content.split("),(")]
+            # expecting two points -> convert to [xmin, ymin, xmax, ymax]
+            if len(parsed) >= 2:
+                boxes = [[parsed[0][0], parsed[0][1], parsed[1][0], parsed[1][1]]]
+        except Exception:
+            boxes = []
+
+    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height) if boxes else []
+    annotated = draw_bounding_boxes(image.copy(), scaled_boxes) if scaled_boxes else image
+
+    return object_ref, scaled_boxes, annotated
 
-    scaled_boxes = rescale_bounding_boxes(boxes, image.width, image.height)
-    return object_ref, scaled_boxes, draw_bounding_boxes(image, scaled_boxes)
 
 css = """
 #output {
-    height: 500px;
-    overflow: auto;
-    border: 1px solid #ccc;
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
 }
 """
+
 with gr.Blocks(css=css) as demo:
-    gr.Markdown(
-        """
-        # Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents
-        """)
+    gr.Markdown("# Demo for OS-ATLAS: A Foundation Action Model For Generalist GUI Agents")
+
     with gr.Row():
         with gr.Column():
             input_img = gr.Image(label="Input Image", type="pil")
-            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="OS-Copilot/OS-Atlas-Base-7B")
+            model_selector = gr.Dropdown(
+                choices=list(models.keys()),
+                label="Model",
+                value="OS-Copilot/OS-Atlas-Base-7B"
+            )
            text_input = gr.Textbox(label="User Prompt")
             submit_btn = gr.Button(value="Submit")
         with gr.Column():
@@ -131,9 +144,14 @@ with gr.Blocks(css=css) as demo:
         outputs=[model_output_text, model_output_box, annotated_image],
         fn=run_example,
         cache_examples=True,
-        label="Try examples"
+        label="Try examples",
     )
 
-    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, model_output_box, annotated_image])
+    submit_btn.click(
+        run_example,
+        [input_img, text_input, model_selector],
+        [model_output_text, model_output_box, annotated_image],
+    )
 
-demo.launch(debug=True)
+# ---- HF Spaces: bind to all interfaces + use provided port; disable API schema to avoid json-schema bug ----
+demo.queue(api_open=False).launch(server_name="0.0.0.0", server_port=PORT, show_error=True, debug=True)
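
For reference, the defensive parsing added above turns the model's <|box_start|>(x1,y1),(x2,y2)<|box_end|> span into a [xmin, ymin, xmax, ymax] list. A self-contained sketch of that conversion (the sample string is illustrative of the OS-Atlas output format, not a captured model response):

import re

# Hypothetical output in the OS-Atlas special-token format.
sample = '<|object_ref_start|>submit button<|object_ref_end|><|box_start|>(100,200),(300,400)<|box_end|>'

box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
m = re.search(box_pattern, sample)
box_content = m.group(1) if m else ""  # "(100,200),(300,400)" here; "" if the model omits the span

pairs = [tuple(map(int, p.strip("()").split(","))) for p in box_content.split("),(")]
print(pairs)                                                   # [(100, 200), (300, 400)]
print([[pairs[0][0], pairs[0][1], pairs[1][0], pairs[1][1]]])  # [[100, 200, 300, 400]]

Because every step is guarded (the "if m else" fallback and the try/except in the committed code), a malformed or missing span now yields an empty box list and the original image, rather than an unhandled exception in the Space.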