Update app.py
app.py CHANGED
@@ -31,9 +31,9 @@ try:
 
     model = AutoModelForImageTextToText.from_pretrained(
         MODEL_ID,
-        torch_dtype=
-
-        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        device_map="auto",
         trust_remote_code=True
     )
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
@@ -44,7 +44,6 @@ except Exception as e:
         "This might be due to network issues, an incorrect model ID, or missing dependencies (like flash_attention_2 if enabled by default in some config).\n" \
         "Ensure you have a stable internet connection and the necessary libraries installed."
     print(load_error_message)
-    # Fallback for Gradio UI to show error
 
 # --- Helper functions from the model card (or adapted) ---
 
@@ -56,24 +55,20 @@ def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
     """
     guidelines: str = "Localize an element on the GUI image according to my instructions and output a click position as Click(x, y) with x num pixels from the left edge and y num pixels from the top edge."
 
-    # The Qwen2-VL processor expects a list of dictionaries for messages.
-    # For apply_chat_template, the image can be represented by its object if the template handles it,
-    # or a placeholder. The Qwen processor inserts an image tag like <img></img>.
     return [
         {
             "role": "user",
             "content": [
                 {
                     "type": "image",
-                    "image": pil_image,
-                    # `apply_chat_template` will convert this to an image tag.
+                    "image": pil_image,
                 },
                 {"type": "text", "text": f"{guidelines}\n{instruction}"},
             ],
         }
     ]
 
-@spaces.GPU
+@spaces.GPU(duration=120)
 def run_inference_localization(
     current_model: AutoModelForImageTextToText,
     current_processor: AutoProcessor,
@@ -242,7 +237,7 @@ if not model_loaded:
 else:
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
-        gr.Markdown(description)
+        # gr.Markdown(description)
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -264,7 +259,7 @@ else:
             inputs=[input_image_component, instruction_component],
             outputs=[output_coords_component, output_image_component],
             fn=predict_click_location,
-            cache_examples=
+            cache_examples="lazy",
         )
 
         gr.Markdown(article)
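
The new attn_implementation="flash_attention_2" only works when the flash-attn package is installed and the GPU supports it, and torch.bfloat16 generally wants Ampere-or-newer hardware. A defensive variant of the loading call (a sketch, not part of this commit; it reuses MODEL_ID from app.py) could fall back to PyTorch's built-in SDPA attention:

import importlib.util

import torch
from transformers import AutoModelForImageTextToText

# Fall back to scaled-dot-product attention when flash-attn is absent.
attn_impl = (
    "flash_attention_2"
    if importlib.util.find_spec("flash_attn") is not None
    else "sdpa"
)

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,  # assumed to be defined earlier in app.py
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
    device_map="auto",
    trust_remote_code=True,
)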
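
@spaces.GPU(duration=120) asks ZeroGPU to hold the GPU for up to 120 seconds per call rather than the shorter default window. A minimal sketch of how the decorator is used (the function body here is hypothetical):

import spaces

@spaces.GPU(duration=120)  # GPU is attached only while this call runs
def generate(prompt: str) -> str:
    # Calls that need CUDA (e.g. model.generate) go here; outside the
    # decorated function the Space runs on CPU.
    ...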
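
cache_examples="lazy" makes Gradio run and cache each example the first time a visitor clicks it, instead of executing every example at build time, when a ZeroGPU Space has no GPU attached. A minimal sketch using the same component names as app.py (the example row itself is made up):

import gradio as gr

gr.Examples(
    examples=[["screenshot.png", "Click the Submit button"]],  # made-up row
    inputs=[input_image_component, instruction_component],
    outputs=[output_coords_component, output_image_component],
    fn=predict_click_location,
    cache_examples="lazy",  # cache on first use, not at startup
)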