Spaces:
Running
on
Zero
Running
on
Zero
Joseph Pollack
commited on
adds examples
Browse files
app.py
CHANGED
|
@@ -3,6 +3,8 @@ import torch
|
|
| 3 |
from PIL import Image
|
| 4 |
import json
|
| 5 |
import os
|
|
|
|
|
|
|
| 6 |
from transformers import AutoProcessor, AutoModelForImageTextToText
|
| 7 |
from typing import List, Dict, Any
|
| 8 |
import logging
|
|
@@ -130,7 +132,7 @@ class LOperatorDemo:
|
|
| 130 |
return f"β Error generating action: {str(e)}"
|
| 131 |
|
| 132 |
@spaces.GPU(duration=90) # 1.5 minutes for chat responses
|
| 133 |
-
def chat_with_model(self, message: str, history: List[Dict[str, str]], image
|
| 134 |
"""Chat interface function for Gradio"""
|
| 135 |
if not self.is_loaded:
|
| 136 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Model not loaded. Please load the model first."}]
|
|
@@ -139,6 +141,19 @@ class LOperatorDemo:
|
|
| 139 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Please upload an Android screenshot image."}]
|
| 140 |
|
| 141 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
# Extract goal and instruction from message
|
| 143 |
if "Goal:" in message and "Step:" in message:
|
| 144 |
# Parse structured input
|
|
@@ -160,7 +175,7 @@ class LOperatorDemo:
|
|
| 160 |
instruction = message
|
| 161 |
|
| 162 |
# Generate action
|
| 163 |
-
response = self.generate_action(
|
| 164 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
|
| 165 |
|
| 166 |
except Exception as e:
|
|
@@ -181,7 +196,42 @@ def load_model():
|
|
| 181 |
logger.error(f"Error loading model: {str(e)}")
|
| 182 |
return f"β Error loading model: {str(e)}"
|
| 183 |
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
def load_example_episodes():
|
| 186 |
"""Load example episodes from the extracted data - properly load images for Gradio"""
|
| 187 |
examples = []
|
|
@@ -197,23 +247,28 @@ def load_example_episodes():
|
|
| 197 |
|
| 198 |
# Check if both files exist
|
| 199 |
if os.path.exists(metadata_path) and os.path.exists(image_path):
|
|
|
|
|
|
|
| 200 |
with open(metadata_path, "r") as f:
|
| 201 |
metadata = json.load(f)
|
| 202 |
|
| 203 |
# Load the image using PIL
|
| 204 |
image = Image.open(image_path)
|
| 205 |
|
| 206 |
-
#
|
| 207 |
-
|
| 208 |
-
image = image.convert("RGB")
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
except Exception as e:
|
| 219 |
logger.warning(f"Could not load example for {episode_dir}: {str(e)}")
|
|
@@ -341,7 +396,20 @@ def create_demo():
|
|
| 341 |
if not goal or not step:
|
| 342 |
return {"error": "Please provide both goal and step"}
|
| 343 |
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
try:
|
| 347 |
# Try to parse as JSON
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
import json
|
| 5 |
import os
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
from transformers import AutoProcessor, AutoModelForImageTextToText
|
| 9 |
from typing import List, Dict, Any
|
| 10 |
import logging
|
|
|
|
| 132 |
return f"β Error generating action: {str(e)}"
|
| 133 |
|
| 134 |
@spaces.GPU(duration=90) # 1.5 minutes for chat responses
|
| 135 |
+
def chat_with_model(self, message: str, history: List[Dict[str, str]], image=None) -> List[Dict[str, str]]:
|
| 136 |
"""Chat interface function for Gradio"""
|
| 137 |
if not self.is_loaded:
|
| 138 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Model not loaded. Please load the model first."}]
|
|
|
|
| 141 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Please upload an Android screenshot image."}]
|
| 142 |
|
| 143 |
try:
|
| 144 |
+
# Handle different image formats
|
| 145 |
+
pil_image = None
|
| 146 |
+
if isinstance(image, str) and image.startswith('data:image/'):
|
| 147 |
+
# Handle base64 image
|
| 148 |
+
pil_image = base64_to_pil(image)
|
| 149 |
+
elif hasattr(image, 'mode'): # PIL Image object
|
| 150 |
+
pil_image = image
|
| 151 |
+
else:
|
| 152 |
+
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Invalid image format. Please upload a valid image."}]
|
| 153 |
+
|
| 154 |
+
if pil_image is None:
|
| 155 |
+
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "β Failed to process image. Please try again."}]
|
| 156 |
+
|
| 157 |
# Extract goal and instruction from message
|
| 158 |
if "Goal:" in message and "Step:" in message:
|
| 159 |
# Parse structured input
|
|
|
|
| 175 |
instruction = message
|
| 176 |
|
| 177 |
# Generate action
|
| 178 |
+
response = self.generate_action(pil_image, goal, instruction)
|
| 179 |
return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
|
| 180 |
|
| 181 |
except Exception as e:
|
|
|
|
| 196 |
logger.error(f"Error loading model: {str(e)}")
|
| 197 |
return f"β Error loading model: {str(e)}"
|
| 198 |
|
| 199 |
+
def pil_to_base64(image):
|
| 200 |
+
"""Convert PIL image to base64 string for Gradio examples"""
|
| 201 |
+
try:
|
| 202 |
+
# Convert to RGB if needed
|
| 203 |
+
if image.mode != "RGB":
|
| 204 |
+
image = image.convert("RGB")
|
| 205 |
+
|
| 206 |
+
# Save to bytes buffer
|
| 207 |
+
buffer = io.BytesIO()
|
| 208 |
+
image.save(buffer, format="PNG")
|
| 209 |
+
buffer.seek(0)
|
| 210 |
+
|
| 211 |
+
# Convert to base64
|
| 212 |
+
img_str = base64.b64encode(buffer.getvalue()).decode()
|
| 213 |
+
return f"data:image/png;base64,{img_str}"
|
| 214 |
+
except Exception as e:
|
| 215 |
+
logger.error(f"Error converting image to base64: {str(e)}")
|
| 216 |
+
return None
|
| 217 |
+
|
| 218 |
+
def base64_to_pil(base64_string):
|
| 219 |
+
"""Convert base64 string to PIL image"""
|
| 220 |
+
try:
|
| 221 |
+
# Remove data URL prefix if present
|
| 222 |
+
if base64_string.startswith('data:image/'):
|
| 223 |
+
base64_string = base64_string.split(',')[1]
|
| 224 |
+
|
| 225 |
+
# Decode base64
|
| 226 |
+
image_data = base64.b64decode(base64_string)
|
| 227 |
+
|
| 228 |
+
# Create PIL image from bytes
|
| 229 |
+
image = Image.open(io.BytesIO(image_data))
|
| 230 |
+
return image
|
| 231 |
+
except Exception as e:
|
| 232 |
+
logger.error(f"Error converting base64 to PIL image: {str(e)}")
|
| 233 |
+
return None
|
| 234 |
+
|
| 235 |
def load_example_episodes():
|
| 236 |
"""Load example episodes from the extracted data - properly load images for Gradio"""
|
| 237 |
examples = []
|
|
|
|
| 247 |
|
| 248 |
# Check if both files exist
|
| 249 |
if os.path.exists(metadata_path) and os.path.exists(image_path):
|
| 250 |
+
logger.info(f"Loading example from {episode_dir}")
|
| 251 |
+
|
| 252 |
with open(metadata_path, "r") as f:
|
| 253 |
metadata = json.load(f)
|
| 254 |
|
| 255 |
# Load the image using PIL
|
| 256 |
image = Image.open(image_path)
|
| 257 |
|
| 258 |
+
# Convert to base64 for Gradio examples
|
| 259 |
+
base64_image = pil_to_base64(image)
|
|
|
|
| 260 |
|
| 261 |
+
if base64_image:
|
| 262 |
+
episode_num = episode_dir.split('_')[1]
|
| 263 |
+
goal_text = metadata.get('goal', f'Episode {episode_num} example')
|
| 264 |
+
|
| 265 |
+
examples.append([
|
| 266 |
+
base64_image, # Use base64 encoded image
|
| 267 |
+
f"Episode {episode_num}: {goal_text[:50]}..."
|
| 268 |
+
])
|
| 269 |
+
logger.info(f"Successfully loaded example for Episode {episode_num}")
|
| 270 |
+
else:
|
| 271 |
+
logger.warning(f"Failed to convert image to base64 for {episode_dir}")
|
| 272 |
|
| 273 |
except Exception as e:
|
| 274 |
logger.warning(f"Could not load example for {episode_dir}: {str(e)}")
|
|
|
|
| 396 |
if not goal or not step:
|
| 397 |
return {"error": "Please provide both goal and step"}
|
| 398 |
|
| 399 |
+
# Handle different image formats
|
| 400 |
+
pil_image = None
|
| 401 |
+
if isinstance(image, str) and image.startswith('data:image/'):
|
| 402 |
+
# Handle base64 image
|
| 403 |
+
pil_image = base64_to_pil(image)
|
| 404 |
+
elif hasattr(image, 'mode'): # PIL Image object
|
| 405 |
+
pil_image = image
|
| 406 |
+
else:
|
| 407 |
+
return {"error": "Invalid image format. Please upload a valid image."}
|
| 408 |
+
|
| 409 |
+
if pil_image is None:
|
| 410 |
+
return {"error": "Failed to process image. Please try again."}
|
| 411 |
+
|
| 412 |
+
response = demo_instance.generate_action(pil_image, goal, step)
|
| 413 |
|
| 414 |
try:
|
| 415 |
# Try to parse as JSON
|