multimodalart (HF Staff) committed
Commit b46c8a5 · verified · 1 Parent(s): 4f95f14

Update app.py

Files changed (1):
  1. app.py +7 -12
app.py CHANGED
@@ -31,9 +31,9 @@ try:
 
     model = AutoModelForImageTextToText.from_pretrained(
         MODEL_ID,
-        torch_dtype="auto", # Uses torch.float16 if CUDA is available, else float32
-        # attn_implementation=attn_implementation, # Enable if flash_attention_2 is installed and compatible
-        device_map="auto", # Automatically uses CUDA if available
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        device_map="auto",
         trust_remote_code=True
     )
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
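
The new call pins torch.bfloat16 and FlashAttention-2 unconditionally, where the old code let Transformers pick the dtype and left the attention backend at its default. If the app ever needs to run somewhere without flash-attn installed, a guarded variant along the following lines would keep loading working. This is a sketch under that assumption, not code from this commit; the real MODEL_ID value from app.py is elided.

```python
# Sketch only (not part of this commit): prefer bf16 + FlashAttention-2,
# fall back to the previous "auto" behaviour if flash-attn is unavailable.
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

MODEL_ID = "..."  # placeholder; app.py defines the actual model ID

try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
        trust_remote_code=True,
    )
except Exception:
    # flash-attn missing or unsupported on this hardware: let Transformers
    # choose the dtype and use its default attention implementation.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
```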
@@ -44,7 +44,6 @@ except Exception as e:
         "This might be due to network issues, an incorrect model ID, or missing dependencies (like flash_attention_2 if enabled by default in some config).\n" \
         "Ensure you have a stable internet connection and the necessary libraries installed."
     print(load_error_message)
-    # Fallback for Gradio UI to show error
 
 # --- Helper functions from the model card (or adapted) ---
 
@@ -56,24 +55,20 @@ def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
     """
     guidelines: str = "Localize an element on the GUI image according to my instructions and output a click position as Click(x, y) with x num pixels from the left edge and y num pixels from the top edge."
 
-    # The Qwen2-VL processor expects a list of dictionaries for messages.
-    # For apply_chat_template, the image can be represented by its object if the template handles it,
-    # or a placeholder. The Qwen processor inserts an image tag like <img></img>.
     return [
         {
             "role": "user",
             "content": [
                 {
                     "type": "image",
-                    "image": pil_image, # Passing the PIL image object here, as in the model card.
-                                        # `apply_chat_template` will convert this to an image tag.
+                    "image": pil_image,
                 },
                 {"type": "text", "text": f"{guidelines}\n{instruction}"},
             ],
         }
     ]
 
-@spaces.GPU
+@spaces.GPU(duration=120)
 def run_inference_localization(
     current_model: AutoModelForImageTextToText,
     current_processor: AutoProcessor,
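
The removed comments described how `apply_chat_template` replaces the `"type": "image"` entry with the processor's image placeholder. For reference, a typical consumption pattern for a Qwen2-VL-style processor looks like the sketch below; it assumes the `model`, `processor`, and `get_localization_prompt` already defined in app.py are in scope, and the file name and `max_new_tokens` value are illustrative, not taken from this commit.

```python
# Sketch only: how the message list above is commonly fed to a
# Qwen2-VL-style processor. "screenshot.png" and max_new_tokens are
# placeholders; model/processor/get_localization_prompt come from app.py.
from PIL import Image

pil_image = Image.open("screenshot.png")
messages = get_localization_prompt(pil_image, "Click the search bar")

# apply_chat_template renders the chat prompt, swapping the image entry
# for the processor's image placeholder tokens.
text_prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# The processor then pairs the rendered prompt with the actual pixel data.
inputs = processor(
    text=[text_prompt], images=[pil_image], return_tensors="pt"
).to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=128)
# Strip the prompt tokens and decode only the newly generated part.
answer = processor.batch_decode(
    generated_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(answer)  # expected to look like: Click(x, y)
```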
@@ -242,7 +237,7 @@ if not model_loaded:
 else:
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
-        gr.Markdown(description)
+        # gr.Markdown(description)
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -264,7 +259,7 @@ else:
                 inputs=[input_image_component, instruction_component],
                 outputs=[output_coords_component, output_image_component],
                 fn=predict_click_location,
-                cache_examples=False, # Re-run for dynamic examples if needed, but False is safer for resource limits
+                cache_examples="lazy",
             )
 
             gr.Markdown(article)
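
The last two changes interact: `@spaces.GPU(duration=120)` asks ZeroGPU for up to 120 seconds of GPU time per call instead of the default allocation, and `cache_examples="lazy"` makes Gradio list the examples immediately and run/cache each one only the first time a visitor selects it, rather than executing them all at startup. A minimal wiring sketch follows, with stand-in component labels, a placeholder example row, and stubbed function bodies rather than app.py's actual code.

```python
# Sketch only: how the ZeroGPU decorator and lazy example caching fit together.
# Function bodies, labels, and the example row are stand-ins, not app.py's code.
import gradio as gr
import spaces  # ZeroGPU helper package available on Hugging Face Spaces


@spaces.GPU(duration=120)  # request up to 120 s of GPU time per invocation
def run_inference_localization(image, instruction):
    # ... model/processor inference would run here ...
    return "Click(100, 200)"


def predict_click_location(image, instruction):
    coords = run_inference_localization(image, instruction)
    return coords, image  # coordinates text plus the (annotated) screenshot


with gr.Blocks() as demo:
    input_image_component = gr.Image(type="pil", label="Screenshot")
    instruction_component = gr.Textbox(label="Instruction")
    output_coords_component = gr.Textbox(label="Predicted click")
    output_image_component = gr.Image(label="Annotated screenshot")

    gr.Examples(
        examples=[["screenshot.png", "Click the search bar"]],  # placeholder row
        inputs=[input_image_component, instruction_component],
        outputs=[output_coords_component, output_image_component],
        fn=predict_click_location,
        # "lazy": examples appear immediately; each one is executed and cached
        # only the first time a user clicks it, instead of at app startup.
        cache_examples="lazy",
    )

demo.launch()
```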
 