doinglean committed on
Commit
440bc10
·
verified ·
1 Parent(s): 2b81102

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -67
app.py CHANGED
@@ -23,87 +23,103 @@ def analyze_image(image, prompt):
23
  logger.info("Starting image analysis with prompt: %s", prompt)
24
 
25
  # Konvertiere PIL-Bild zu numpy-Format
26
- image_np = np.array(image)
27
- image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
28
- logger.info("Image shape: %s", image_np.shape)
 
 
 
 
29
 
30
  # Bildvorverarbeitung: Kontrast erhöhen
31
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
32
- gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
33
- enhanced = clahe.apply(gray)
34
- image_cv = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
 
 
 
 
35
 
36
  # Allgemeine Bildbeschreibung
37
  if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
38
- inputs = processor(text=prompt, images=image_np, return_tensors="pt")
39
- with torch.no_grad():
40
- outputs = model.generate(
41
- input_ids=inputs["input_ids"],
42
- pixel_values=inputs["pixel_values"],
43
- max_length=1024,
44
- num_beams=3
45
- )
46
- description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
47
- return {"prompt": prompt, "description": description}
 
 
 
 
48
 
49
  # Kerzen-Analyse
50
  elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
51
- task_prompt = "<OD>" # Objekterkennung
52
- inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
53
- with torch.no_grad():
54
- outputs = model.generate(
55
- input_ids=inputs["input_ids"],
56
- pixel_values=inputs["pixel_values"],
57
- max_length=1024,
58
- num_beams=3
59
- )
60
- predictions = processor.post_process_generation(outputs, task=task_prompt, image_size=(image_np.shape[1], image_np.shape[0]))
61
- logger.info("Detected objects: %s", predictions)
 
62
 
63
- detections = []
64
- if "<OD>" in predictions:
65
- for i, (bbox, label) in enumerate(zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"])):
66
- # Erweitere Filter für Kerzen
67
- if "candle" not in label.lower() and "bar" not in label.lower() and "chart" not in label.lower() and "candlestick" not in label.lower():
68
- continue
69
- xmin, ymin, xmax, ymax = map(int, bbox)
70
-
71
- # Extrahiere Farbe
72
- candle_roi = image_cv[ymin:ymax, xmin:xmax]
73
- if candle_roi.size == 0:
74
- logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
75
- continue
76
- mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
77
- color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
78
 
79
- # OCR für Preise (erweiterte ROI)
80
- price_roi = image_cv[max(0, ymin-200):min(image_np.shape[0], ymax+200),
81
- max(0, xmin-200):min(image_np.shape[1], xmax+200)]
82
- ocr_inputs = processor(text="<OCR>", images=price_roi, return_tensors="pt")
83
- with torch.no_grad():
84
- ocr_outputs = model.generate(
85
- input_ids=ocr_inputs["input_ids"],
86
- pixel_values=ocr_inputs["pixel_values"],
87
- max_length=1024
88
- )
89
- prices = processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
90
 
91
- detections.append({
92
- "pattern": label,
93
- "color": color_rgb,
94
- "prices": prices if prices else "No price detected",
95
- "x_center": (xmin + xmax) / 2
96
- })
97
 
98
- # Sortiere nach x-Position (rechts nach links = neueste Kerzen)
99
- detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)[:8]
100
- logger.info("Sorted detections: %d", len(detections))
101
 
102
- if not detections:
103
- logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
104
- return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}
105
 
106
- return {"prompt": prompt, "detections": detections}
 
 
 
107
 
108
  else:
109
  return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}
 
23
  logger.info("Starting image analysis with prompt: %s", prompt)
24
 
25
  # Konvertiere PIL-Bild zu numpy-Format
26
+ try:
27
+ image_np = np.array(image)
28
+ image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
29
+ logger.info("Image shape: %s", image_np.shape)
30
+ except Exception as e:
31
+ logger.error("Failed to process image: %s", str(e))
32
+ return {"prompt": prompt, "description": "Error processing image. Ensure a valid image is uploaded."}
33
 
34
  # Bildvorverarbeitung: Kontrast erhöhen
35
+ try:
36
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
37
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
38
+ enhanced = clahe.apply(gray)
39
+ image_cv = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
40
+ logger.info("Image preprocessing completed")
41
+ except Exception as e:
42
+ logger.warning("Failed to preprocess image: %s", str(e))
43
 
44
  # Allgemeine Bildbeschreibung
45
  if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
46
+ try:
47
+ inputs = processor(text=prompt, images=image_np, return_tensors="pt")
48
+ with torch.no_grad():
49
+ outputs = model.generate(
50
+ input_ids=inputs["input_ids"],
51
+ pixel_values=inputs["pixel_values"],
52
+ max_length=1024,
53
+ num_beams=3
54
+ )
55
+ description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
56
+ return {"prompt": prompt, "description": description}
57
+ except Exception as e:
58
+ logger.error("Failed to generate description: %s", str(e))
59
+ return {"prompt": prompt, "description": "Error generating description. Try again with a clear image."}
60
 
61
  # Kerzen-Analyse
62
  elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
63
+ try:
64
+ task_prompt = "<OD>" # Objekterkennung
65
+ inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
66
+ with torch.no_grad():
67
+ outputs = model.generate(
68
+ input_ids=inputs["input_ids"],
69
+ pixel_values=inputs["pixel_values"],
70
+ max_length=1024,
71
+ num_beams=3
72
+ )
73
+ predictions = processor.post_process_generation(outputs, task=task_prompt, image_size=(image_np.shape[1], image_np.shape[0]))
74
+ logger.info("Detected objects: %s", predictions)
75
 
76
+ detections = []
77
+ if "<OD>" in predictions:
78
+ for i, (bbox, label) in enumerate(zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"])):
79
+ # Erweitere Filter für Kerzen
80
+ if "candle" not in label.lower() and "bar" not in label.lower() and "chart" not in label.lower() and "candlestick" not in label.lower():
81
+ continue
82
+ xmin, ymin, xmax, ymax = map(int, bbox)
83
+
84
+ # Extrahiere Farbe
85
+ candle_roi = image_cv[ymin:ymax, xmin:xmax]
86
+ if candle_roi.size == 0:
87
+ logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
88
+ continue
89
+ mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
90
+ color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
91
 
92
+ # OCR für Preise (erweiterte ROI)
93
+ price_roi = image_cv[max(0, ymin-200):min(image_np.shape[0], ymax+200),
94
+ max(0, xmin-200):min(image_np.shape[1], xmax+200)]
95
+ ocr_inputs = processor(text="<OCR>", images=price_roi, return_tensors="pt")
96
+ with torch.no_grad():
97
+ ocr_outputs = model.generate(
98
+ input_ids=ocr_inputs["input_ids"],
99
+ pixel_values=ocr_inputs["pixel_values"],
100
+ max_length=1024
101
+ )
102
+ prices = processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
103
 
104
+ detections.append({
105
+ "pattern": label,
106
+ "color": color_rgb,
107
+ "prices": prices if prices else "No price detected",
108
+ "x_center": (xmin + xmax) / 2
109
+ })
110
 
111
+ # Sortiere nach x-Position (rechts nach links = neueste Kerzen)
112
+ detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)[:8]
113
+ logger.info("Sorted detections: %d", len(detections))
114
 
115
+ if not detections:
116
+ logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
117
+ return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}
118
 
119
+ return {"prompt": prompt, "detections": detections}
120
+ except Exception as e:
121
+ logger.error("Failed to analyze candles: %s", str(e))
122
+ return {"prompt": prompt, "description": "Error analyzing candles. Try a clearer screenshot with visible candles and prices."}
123
 
124
  else:
125
  return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}