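"""Gradio app that analyzes images with Microsoft's Florence-2 model.

Two prompt types are supported: a general image description ("Was siehst du
auf dem Bild?" / "What do you see?") and extraction of the last 8 candlesticks
(mean color plus nearby OCR'd prices) from a trading-chart screenshot.
"""
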
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import numpy as np
import cv2
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model and processor (trust_remote_code is required because
# Florence-2 ships its own modeling code on the Hub)
try:
    logger.info("Loading model: microsoft/Florence-2-base")
    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
    logger.info("Model and processor loaded successfully")
except Exception as e:
    logger.error("Failed to load model: %s", str(e))
    raise
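
# NOTE: Florence-2 is steered by task tokens rather than free-form text
# prompts; the task tokens used below (<MORE_DETAILED_CAPTION>, <OD>, <OCR>)
# come from the Florence-2 model card.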

def analyze_image(image, prompt):
    logger.info("Starting image analysis with prompt: %s", prompt)

    if image is None:
        return {"prompt": prompt, "description": "No image uploaded. Please upload an image first."}

    # Convert the PIL image to numpy (RGB for the model, BGR copy for OpenCV)
    try:
        image_np = np.array(image)
        image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        logger.info("Image shape: %s", image_np.shape)
    except Exception as e:
        logger.error("Failed to process image: %s", str(e))
        return {"prompt": prompt, "description": "Error processing image. Ensure a valid image is uploaded."}

    # Preprocessing: boost contrast with CLAHE. Keep the enhanced copy
    # separate so candle colors can still be read from the original BGR image
    # (overwriting image_cv with a grayscale-derived result would make every
    # candle look gray).
    image_enhanced = image_cv
    try:
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
        image_enhanced = cv2.cvtColor(clahe.apply(gray), cv2.COLOR_GRAY2BGR)
        logger.info("Image preprocessing completed")
    except Exception as e:
        logger.warning("Failed to preprocess image: %s", str(e))

    # General image description. Florence-2 does not answer free-form
    # questions, so the user's prompt is mapped to the <MORE_DETAILED_CAPTION>
    # task.
    if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
        try:
            inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image_np, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    num_beams=3
                )
            description = processor.batch_decode(outputs, skip_special_tokens=True)[0]
            return {"prompt": prompt, "description": description}
        except Exception as e:
            logger.error("Failed to generate description: %s", str(e))
            return {"prompt": prompt, "description": "Error generating description. Try again with a clear image."}

    # Candlestick analysis
    elif "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
        try:
            task_prompt = "<OD>"  # object detection task token
            inputs = processor(text=task_prompt, images=image_np, return_tensors="pt")
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    num_beams=3
                )
            # post_process_generation expects the decoded text (special tokens
            # intact), not the raw output tensor
            generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
            predictions = processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image_np.shape[1], image_np.shape[0])
            )
            logger.info("Detected objects: %s", predictions)

            detections = []
            if "<OD>" in predictions:
                for i, (bbox, label) in enumerate(zip(predictions["<OD>"]["bboxes"], predictions["<OD>"]["labels"])):
                    # Erweitere Filter fΓΌr Kerzen
                    if "candle" not in label.lower() and "bar" not in label.lower() and "chart" not in label.lower() and "candlestick" not in label.lower():
                        continue
                    xmin, ymin, xmax, ymax = map(int, bbox)
                    
                    # Extract the mean color from the original (non-enhanced) image
                    candle_roi = image_cv[ymin:ymax, xmin:xmax]
                    if candle_roi.size == 0:
                        logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
                        continue
                    mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
                    color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"

                    # OCR for prices: expand the ROI by 200 px in every
                    # direction so nearby axis labels are included; use the
                    # contrast-enhanced copy, converted back to RGB
                    price_roi = image_enhanced[max(0, ymin-200):min(image_np.shape[0], ymax+200),
                                               max(0, xmin-200):min(image_np.shape[1], xmax+200)]
                    price_roi = cv2.cvtColor(price_roi, cv2.COLOR_BGR2RGB)
                    ocr_inputs = processor(text="<OCR>", images=price_roi, return_tensors="pt")
                    with torch.no_grad():
                        ocr_outputs = model.generate(
                            input_ids=ocr_inputs["input_ids"],
                            pixel_values=ocr_inputs["pixel_values"],
                            max_new_tokens=1024
                        )
                    prices = processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]

                    detections.append({
                        "pattern": label,
                        "color": color_rgb,
                        "prices": prices if prices else "No price detected",
                        "x_center": (xmin + xmax) / 2
                    })

            # Sort by x-position (right to left = newest candles)
            detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)[:8]
            logger.info("Sorted detections: %d", len(detections))

            if not detections:
                logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
                return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}

            return {"prompt": prompt, "detections": detections}
        except Exception as e:
            logger.error("Failed to analyze candles: %s", str(e))
            return {"prompt": prompt, "description": "Error analyzing candles. Try a clearer screenshot with visible candles and prices."}

    else:
        return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}

# Build the Gradio interface
iface = gr.Interface(
    fn=analyze_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'")
    ],
    outputs="json",
    title="Image Analysis with Florence-2-base",
    description="Upload an image to analyze candlesticks or get a general description."
)

iface.launch()
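
# When run locally, Gradio serves the app at http://127.0.0.1:7860 by default.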