Joseph Pollack committed on
Commit
81e328a
·
unverified ·
1 Parent(s): 7dfb388

adds examples

Browse files
Files changed (1) hide show
  1. app.py +82 -14
app.py CHANGED
@@ -3,6 +3,8 @@ import torch
3
  from PIL import Image
4
  import json
5
  import os
 
 
6
  from transformers import AutoProcessor, AutoModelForImageTextToText
7
  from typing import List, Dict, Any
8
  import logging
@@ -130,7 +132,7 @@ class LOperatorDemo:
130
  return f"❌ Error generating action: {str(e)}"
131
 
132
  @spaces.GPU(duration=90) # 1.5 minutes for chat responses
133
- def chat_with_model(self, message: str, history: List[Dict[str, str]], image: Image.Image = None) -> List[Dict[str, str]]:
134
  """Chat interface function for Gradio"""
135
  if not self.is_loaded:
136
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Model not loaded. Please load the model first."}]
@@ -139,6 +141,19 @@ class LOperatorDemo:
139
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please upload an Android screenshot image."}]
140
 
141
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # Extract goal and instruction from message
143
  if "Goal:" in message and "Step:" in message:
144
  # Parse structured input
@@ -160,7 +175,7 @@ class LOperatorDemo:
160
  instruction = message
161
 
162
  # Generate action
163
- response = self.generate_action(image, goal, instruction)
164
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
165
 
166
  except Exception as e:
@@ -181,7 +196,42 @@ def load_model():
181
  logger.error(f"Error loading model: {str(e)}")
182
  return f"❌ Error loading model: {str(e)}"
183
 
184
- # Load example episodes (lazy loading to avoid startup timeout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def load_example_episodes():
186
  """Load example episodes from the extracted data - properly load images for Gradio"""
187
  examples = []
@@ -197,23 +247,28 @@ def load_example_episodes():
197
 
198
  # Check if both files exist
199
  if os.path.exists(metadata_path) and os.path.exists(image_path):
 
 
200
  with open(metadata_path, "r") as f:
201
  metadata = json.load(f)
202
 
203
  # Load the image using PIL
204
  image = Image.open(image_path)
205
 
206
- # Ensure image is in RGB mode
207
- if image.mode != "RGB":
208
- image = image.convert("RGB")
209
 
210
- episode_num = episode_dir.split('_')[1]
211
- goal_text = metadata.get('goal', f'Episode {episode_num} example')
212
-
213
- examples.append([
214
- image, # Use PIL Image object instead of file path
215
- f"Episode {episode_num}: {goal_text[:50]}..."
216
- ])
 
 
 
 
217
 
218
  except Exception as e:
219
  logger.warning(f"Could not load example for {episode_dir}: {str(e)}")
@@ -341,7 +396,20 @@ def create_demo():
341
  if not goal or not step:
342
  return {"error": "Please provide both goal and step"}
343
 
344
- response = demo_instance.generate_action(image, goal, step)
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  try:
347
  # Try to parse as JSON
 
3
  from PIL import Image
4
  import json
5
  import os
6
+ import base64
7
+ import io
8
  from transformers import AutoProcessor, AutoModelForImageTextToText
9
  from typing import List, Dict, Any
10
  import logging
 
132
  return f"❌ Error generating action: {str(e)}"
133
 
134
  @spaces.GPU(duration=90) # 1.5 minutes for chat responses
135
+ def chat_with_model(self, message: str, history: List[Dict[str, str]], image=None) -> List[Dict[str, str]]:
136
  """Chat interface function for Gradio"""
137
  if not self.is_loaded:
138
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Model not loaded. Please load the model first."}]
 
141
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please upload an Android screenshot image."}]
142
 
143
  try:
144
+ # Handle different image formats
145
+ pil_image = None
146
+ if isinstance(image, str) and image.startswith('data:image/'):
147
+ # Handle base64 image
148
+ pil_image = base64_to_pil(image)
149
+ elif hasattr(image, 'mode'): # PIL Image object
150
+ pil_image = image
151
+ else:
152
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Invalid image format. Please upload a valid image."}]
153
+
154
+ if pil_image is None:
155
+ return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Failed to process image. Please try again."}]
156
+
157
  # Extract goal and instruction from message
158
  if "Goal:" in message and "Step:" in message:
159
  # Parse structured input
 
175
  instruction = message
176
 
177
  # Generate action
178
+ response = self.generate_action(pil_image, goal, instruction)
179
  return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
180
 
181
  except Exception as e:
 
196
  logger.error(f"Error loading model: {str(e)}")
197
  return f"❌ Error loading model: {str(e)}"
198
 
199
def pil_to_base64(image):
    """Encode a PIL image as a PNG data-URL string for Gradio examples.

    Returns a string of the form "data:image/png;base64,...", or None if
    encoding fails for any reason (the error is logged).
    """
    try:
        # Normalize to RGB so PNG encoding is predictable.
        rgb = image if image.mode == "RGB" else image.convert("RGB")

        # Render the image into an in-memory PNG buffer.
        buf = io.BytesIO()
        rgb.save(buf, format="PNG")
        buf.seek(0)

        # Base64-encode the PNG bytes and wrap them as a data URL.
        encoded = base64.b64encode(buf.getvalue()).decode()
        return f"data:image/png;base64,{encoded}"
    except Exception as e:
        logger.error(f"Error converting image to base64: {str(e)}")
        return None
217
+
218
def base64_to_pil(base64_string):
    """Decode a base64 string (optionally a "data:image/..." URL) into a PIL image.

    Returns the decoded PIL image, or None if decoding fails (the error is
    logged).
    """
    try:
        payload = base64_string
        # Strip the data-URL header, keeping only the base64 payload.
        if payload.startswith('data:image/'):
            payload = payload.split(',')[1]

        # Decode the payload and let PIL parse the raw image bytes.
        raw = base64.b64decode(payload)
        return Image.open(io.BytesIO(raw))
    except Exception as e:
        logger.error(f"Error converting base64 to PIL image: {str(e)}")
        return None
234
+
235
  def load_example_episodes():
236
  """Load example episodes from the extracted data - properly load images for Gradio"""
237
  examples = []
 
247
 
248
  # Check if both files exist
249
  if os.path.exists(metadata_path) and os.path.exists(image_path):
250
+ logger.info(f"Loading example from {episode_dir}")
251
+
252
  with open(metadata_path, "r") as f:
253
  metadata = json.load(f)
254
 
255
  # Load the image using PIL
256
  image = Image.open(image_path)
257
 
258
+ # Convert to base64 for Gradio examples
259
+ base64_image = pil_to_base64(image)
 
260
 
261
+ if base64_image:
262
+ episode_num = episode_dir.split('_')[1]
263
+ goal_text = metadata.get('goal', f'Episode {episode_num} example')
264
+
265
+ examples.append([
266
+ base64_image, # Use base64 encoded image
267
+ f"Episode {episode_num}: {goal_text[:50]}..."
268
+ ])
269
+ logger.info(f"Successfully loaded example for Episode {episode_num}")
270
+ else:
271
+ logger.warning(f"Failed to convert image to base64 for {episode_dir}")
272
 
273
  except Exception as e:
274
  logger.warning(f"Could not load example for {episode_dir}: {str(e)}")
 
396
  if not goal or not step:
397
  return {"error": "Please provide both goal and step"}
398
 
399
+ # Handle different image formats
400
+ pil_image = None
401
+ if isinstance(image, str) and image.startswith('data:image/'):
402
+ # Handle base64 image
403
+ pil_image = base64_to_pil(image)
404
+ elif hasattr(image, 'mode'): # PIL Image object
405
+ pil_image = image
406
+ else:
407
+ return {"error": "Invalid image format. Please upload a valid image."}
408
+
409
+ if pil_image is None:
410
+ return {"error": "Failed to process image. Please try again."}
411
+
412
+ response = demo_instance.generate_action(pil_image, goal, step)
413
 
414
  try:
415
  # Try to parse as JSON