doinglean committed on
Commit c897aad · verified · 1 Parent(s): 9434573

Create app.py

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ import torch
+ import numpy as np
+ import cv2
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load the model and the processor
+ model = AutoModelForCausalLM.from_pretrained("microsoft/florence-2-large", trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained("microsoft/florence-2-large", trust_remote_code=True)
+
+ def analyze_image(image, prompt):
+     logger.info("Starting image analysis with prompt: %s", prompt)
+
+     # Convert the PIL image to a numpy array (RGB) and a BGR copy for OpenCV
+     image_np = np.array(image)
+     image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+     logger.info("Image shape: %s", image_np.shape)
+
+     # General image description
+     if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
+         # Florence-2 is trained on task tokens, not free-form prompts,
+         # so map this request to the detailed-captioning task.
+         task_prompt = "<MORE_DETAILED_CAPTION>"
+         inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
+         with torch.no_grad():
+             outputs = model.generate(
+                 input_ids=inputs["input_ids"],
+                 pixel_values=inputs["pixel_values"],
+                 max_length=1024,
+                 num_beams=3
+             )
+         description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+         return {"prompt": prompt, "description": description}
+
+     # Specific candlestick analysis
+     elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
+         # Object detection with Florence-2
+         task_prompt = "<OD>"
+         inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
+         with torch.no_grad():
+             outputs = model.generate(
+                 input_ids=inputs["input_ids"],
+                 pixel_values=inputs["pixel_values"],
+                 max_length=1024,
+                 num_beams=3
+             )
+         # post_process_generation expects the decoded text, not raw token ids
+         generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
+         predictions = processor.post_process_generation(
+             generated_text, task=task_prompt,
+             image_size=(image_np.shape[1], image_np.shape[0])
+         )
+         logger.info("Detected objects: %s", predictions)
+
+         # Extract candlesticks
+         detections = []
+         if "<OD>" in predictions:
+             for bbox, label in zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"]):
+                 if "candle" not in label.lower():  # keep only candles
+                     continue
+                 xmin, ymin, xmax, ymax = map(int, bbox)
+
+                 # Extract the mean color of the candle region
+                 candle_roi = image_cv[ymin:ymax, xmin:xmax]
+                 if candle_roi.size == 0:
+                     logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
+                     continue
+                 mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
+                 color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"  # BGR -> RGB
+
+                 # OCR for prices (axis labels or candle region); convert the
+                 # crop back to RGB because image_cv is BGR
+                 price_roi = image_cv[max(0, ymin-50):min(image_np.shape[0], ymax+50),
+                                      max(0, xmin-50):min(image_np.shape[1], xmax+50)]
+                 price_roi = cv2.cvtColor(price_roi, cv2.COLOR_BGR2RGB)
+                 price_task = "<OCR>"
+                 ocr_inputs = processor(text=price_task, images=price_roi, return_tensors="pt")
+                 with torch.no_grad():
+                     ocr_outputs = model.generate(
+                         input_ids=ocr_inputs["input_ids"],
+                         pixel_values=ocr_inputs["pixel_values"],
+                         max_length=1024
+                     )
+                 prices = processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]
+
+                 detections.append({
+                     "pattern": label,
+                     "color": color_rgb,
+                     "prices": prices if prices else "No price detected",
+                     "x_center": (xmin + xmax) / 2
+                 })
+
+         # Sort by x-position (right to left = newest candles) and keep the last 8
+         detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)[:8]
+         logger.info("Sorted detections: %d", len(detections))
+
+         if not detections:
+             logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
+             return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot."}
+
+         return {"prompt": prompt, "detections": detections}
+
+     # Fallback for unknown prompts
+     else:
+         return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}
+
+ # Build the Gradio interface
+ iface = gr.Interface(
+     fn=analyze_image,
+     inputs=[
+         gr.Image(type="pil", label="Upload an Image"),
+         gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'")
+     ],
+     outputs="json",
+     title="Image Analysis with Florence-2-large",
+     description="Upload an image and provide a prompt to get a description or analyze candlesticks."
+ )
+
+ iface.launch()
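
For a quick local check without the web UI, the handler can also be called directly (for example from a Python shell, before iface.launch() is reached). A minimal sketch, assuming a chart screenshot saved as chart.png; the filename is a placeholder, not part of this commit:

from PIL import Image

# Hypothetical smoke test: call the handler directly with a PIL image.
img = Image.open("chart.png").convert("RGB")  # placeholder path

print(analyze_image(img, "Was siehst du auf dem Bild?"))            # caption branch
print(analyze_image(img, "List last 8 candles with their colors"))  # candlestick branch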