doinglean committed on
Commit
82434dd
Β·
verified Β·
1 Parent(s): 207c330

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -28
app.py CHANGED
@@ -10,8 +10,14 @@ logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
  # Lade das Modell und den Processor
13
- model = AutoModelForCausalLM.from_pretrained("microsoft/florence-2-large", trust_remote_code=True)
14
- processor = AutoProcessor.from_pretrained("microsoft/florence-2-large", trust_remote_code=True)
 
 
 
 
 
 
15
 
16
  def analyze_image(image, prompt):
17
  logger.info("Starting image analysis with prompt: %s", prompt)
@@ -21,6 +27,12 @@ def analyze_image(image, prompt):
21
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
22
  logger.info("Image shape: %s", image_np.shape)
23
 
 
 
 
 
 
 
24
  # Allgemeine Bildbeschreibung
25
  if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
26
  inputs = processor(text=prompt, images=image_np, return_tensors="pt")
@@ -34,9 +46,8 @@ def analyze_image(image, prompt):
34
  description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
35
  return {"prompt": prompt, "description": description}
36
 
37
- # Spezifische Kerzen-Analyse
38
  elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
39
- # Objekterkennung mit Florence-2
40
  task_prompt = "<OD>" # Objekterkennung
41
  inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
42
  with torch.no_grad():
@@ -49,11 +60,11 @@ def analyze_image(image, prompt):
49
  predictions = processor.post_process_generation(outputs, task=task_prompt, image_size=(image_np.shape[1], image_np.shape[0]))
50
  logger.info("Detected objects: %s", predictions)
51
 
52
- # Extrahiere Kerzen
53
  detections = []
54
  if "<OD>" in predictions:
55
  for i, (bbox, label) in enumerate(zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"])):
56
- if "candle" not in label.lower(): # Filtere nur Kerzen
 
57
  continue
58
  xmin, ymin, xmax, ymax = map(int, bbox)
59
 
@@ -65,11 +76,10 @@ def analyze_image(image, prompt):
65
  mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
66
  color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
67
 
68
- # OCR fΓΌr Preise (Achsen oder Kerzenregion)
69
- price_roi = image_cv[max(0, ymin-50):min(image_np.shape[0], ymax+50),
70
- max(0, xmin-50):min(image_np.shape[1], xmax+50)]
71
- price_task = "<OCR>"
72
- ocr_inputs = processor(text=price_task, images=price_roi, return_tensors="pt")
73
  with torch.no_grad():
74
  ocr_outputs = model.generate(
75
  input_ids=ocr_inputs["input_ids"],
@@ -91,24 +101,9 @@ def analyze_image(image, prompt):
91
 
92
  if not detections:
93
  logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
94
- return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot."}
95
 
96
  return {"prompt": prompt, "detections": detections}
97
 
98
- # Fallback fΓΌr unbekannte Prompts
99
  else:
100
- return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}
101
-
102
- # Erstelle Gradio-Schnittstelle
103
- iface = gr.Interface(
104
- fn=analyze_image,
105
- inputs=[
106
- gr.Image(type="pil", label="Upload an Image"),
107
- gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'")
108
- ],
109
- outputs="json",
110
- title="Image Analysis with Florence-2-large",
111
- description="Upload an image and provide a prompt to get a description or analyze candlesticks."
112
- )
113
-
114
- iface.launch()
 
10
  logger = logging.getLogger(__name__)
11
 
12
  # Lade das Modell und den Processor
13
+ try:
14
+ logger.info("Loading model: microsoft/florence-2-base")
15
+ model = AutoModelForCausalLM.from_pretrained("microsoft/florence-2-base", trust_remote_code=True)
16
+ processor = AutoProcessor.from_pretrained("microsoft/florence-2-base", trust_remote_code=True)
17
+ logger.info("Model and processor loaded successfully")
18
+ except Exception as e:
19
+ logger.error("Failed to load model: %s", str(e))
20
+ raise
21
 
22
  def analyze_image(image, prompt):
23
  logger.info("Starting image analysis with prompt: %s", prompt)
 
27
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
28
  logger.info("Image shape: %s", image_np.shape)
29
 
30
+ # Bildvorverarbeitung: Kontrast erhΓΆhen
31
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
32
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
33
+ enhanced = clahe.apply(gray)
34
+ image_cv = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
35
+
36
  # Allgemeine Bildbeschreibung
37
  if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
38
  inputs = processor(text=prompt, images=image_np, return_tensors="pt")
 
46
  description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
47
  return {"prompt": prompt, "description": description}
48
 
49
+ # Kerzen-Analyse
50
  elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
 
51
  task_prompt = "<OD>" # Objekterkennung
52
  inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
53
  with torch.no_grad():
 
60
  predictions = processor.post_process_generation(outputs, task=task_prompt, image_size=(image_np.shape[1], image_np.shape[0]))
61
  logger.info("Detected objects: %s", predictions)
62
 
 
63
  detections = []
64
  if "<OD>" in predictions:
65
  for i, (bbox, label) in enumerate(zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"])):
66
+ # Erweitere Filter fΓΌr Kerzen
67
+ if "candle" not in label.lower() and "bar" not in label.lower() and "chart" not in label.lower():
68
  continue
69
  xmin, ymin, xmax, ymax = map(int, bbox)
70
 
 
76
  mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
77
  color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
78
 
79
+ # OCR fΓΌr Preise (erweiterte ROI)
80
+ price_roi = image_cv[max(0, ymin-200):min(image_np.shape[0], ymax+200),
81
+ max(0, xmin-200):min(image_np.shape[1], xmax+200)]
82
+ ocr_inputs = processor(text="<OCR>", images=price_roi, return_tensors="pt")
 
83
  with torch.no_grad():
84
  ocr_outputs = model.generate(
85
  input_ids=ocr_inputs["input_ids"],
 
101
 
102
  if not detections:
103
  logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
104
+ return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}
105
 
106
  return {"prompt": prompt, "detections": detections}
107
 
 
108
  else:
109
+ return {"prompt": prompt, "description": "Unsupported prompt