Spaces:
Runtime error
Runtime error
File size: 6,611 Bytes
c897aad 82434dd c897aad 440bc10 c897aad 82434dd 440bc10 82434dd c897aad 440bc10 c897aad 82434dd c897aad 440bc10 c897aad 440bc10 c897aad 440bc10 c897aad 440bc10 c897aad 440bc10 c897aad 440bc10 c897aad 440bc10 c897aad 784ebf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import numpy as np
import cv2
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id shared by the model and its processor.
MODEL_ID = "microsoft/florence-2-base"

# Load the model and processor once at import time so every request reuses them.
# trust_remote_code=True executes Python code shipped in the model repository.
try:
    logger.info("Loading model: %s", MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    logger.info("Model and processor loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback, which is what explains a failed startup.
    logger.exception("Failed to load model: %s", e)
    raise
# Florence-2 task tokens: the model is driven by these tokens, not by free-form questions.
_CAPTION_TASK = "<MORE_DETAILED_CAPTION>"
_DETECTION_TASK = "<OD>"
_OCR_TASK = "<OCR>"
# Generation length cap used for every Florence-2 call.
_MAX_LENGTH = 1024
# Number of right-most (newest) candle boxes that are reported.
_MAX_CANDLES = 8
# Pixels added on every side of a candle box before OCR so nearby price labels are included.
_PRICE_MARGIN = 200
# Detection labels containing any of these words are treated as candles
# ("candlestick" is already covered by "candle").
_CANDLE_KEYWORDS = ("candle", "bar", "chart")


def _to_rgb_array(image):
    """Return *image* as a contiguous (H, W, 3) RGB numpy array.

    PIL images of any mode (RGBA, L, P, ...) are converted to RGB first;
    numpy input may be 2-D grayscale, 3-channel RGB or 4-channel RGBA.

    Raises:
        ValueError: if no image is given or it cannot be made 3-channel.
    """
    if image is None:
        raise ValueError("no image supplied")
    if hasattr(image, "convert"):
        # PIL image: screenshots are often RGBA or palette PNGs.
        image = image.convert("RGB")
    array = np.asarray(image)
    if array.ndim == 2:
        array = np.stack([array, array, array], axis=-1)
    elif array.ndim == 3 and array.shape[2] == 4:
        array = array[:, :, :3]
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError(f"unsupported image shape {array.shape}")
    # OpenCV functions require contiguous buffers; slicing off alpha is not.
    return np.ascontiguousarray(array)


def _enhance_contrast(image_rgb):
    """Return a CLAHE-enhanced grayscale copy of *image_rgb*, replicated to 3 RGB channels."""
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    return cv2.cvtColor(clahe.apply(gray), cv2.COLOR_GRAY2RGB)


def _clip_box(bbox, width, height):
    """Truncate an (xmin, ymin, xmax, ymax) box to ints clamped to the image bounds.

    Clamping matters because negative slice indices would silently wrap
    around to the far side of the array.
    """
    xmin, ymin, xmax, ymax = (int(v) for v in bbox)
    xmin, xmax = (min(max(v, 0), width) for v in (xmin, xmax))
    ymin, ymax = (min(max(v, 0), height) for v in (ymin, ymax))
    return xmin, ymin, xmax, ymax


def _mean_rgb_label(roi):
    """Format the per-channel mean of an (h, w, 3) RGB region as ``"RGB(r,g,b)"``."""
    mean = np.mean(roi, axis=(0, 1)).astype(int)
    return f"RGB({mean[0]},{mean[1]},{mean[2]})"


def _is_candle_label(label):
    """Return True if a detection label looks like a candle/bar/chart object."""
    lowered = label.lower()
    return any(keyword in lowered for keyword in _CANDLE_KEYWORDS)


def _generate(text, image_rgb, *, num_beams=None, skip_special_tokens=True):
    """Run one Florence-2 generation pass on *image_rgb* and return the decoded text.

    ``num_beams=None`` leaves the beam count to the model's generation config.
    """
    inputs = processor(text=text, images=image_rgb, return_tensors="pt")
    generate_kwargs = {"max_length": _MAX_LENGTH}
    if num_beams is not None:
        generate_kwargs["num_beams"] = num_beams
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            **generate_kwargs,
        )
    return processor.batch_decode(outputs, skip_special_tokens=skip_special_tokens)[0]


def _describe_image(image_rgb, prompt):
    """Handle the general-description prompt with a detailed Florence-2 caption."""
    try:
        description = _generate(_CAPTION_TASK, image_rgb, num_beams=3).strip()
    except Exception as e:
        logger.exception("Failed to generate description: %s", e)
        return {"prompt": prompt, "description": "Error generating description. Try again with a clear image."}
    return {"prompt": prompt, "description": description}


def _describe_candle(label, box, image_rgb, ocr_source):
    """Build the result dict for one candle box: mean colour plus OCR text around it."""
    xmin, ymin, xmax, ymax = box
    height, width = image_rgb.shape[:2]
    # Colour is measured on the original image: the enhanced copy is grayscale.
    color = _mean_rgb_label(image_rgb[ymin:ymax, xmin:xmax])
    price_roi = ocr_source[
        max(0, ymin - _PRICE_MARGIN):min(height, ymax + _PRICE_MARGIN),
        max(0, xmin - _PRICE_MARGIN):min(width, xmax + _PRICE_MARGIN),
    ]
    prices = _generate(_OCR_TASK, np.ascontiguousarray(price_roi)).strip()
    return {
        "pattern": label,
        "color": color,
        "prices": prices if prices else "No price detected",
        "x_center": (xmin + xmax) / 2,
    }


def _analyze_candles(image_rgb, prompt):
    """Handle the "last 8 candles" prompt: detect candle boxes, then colour + OCR the newest."""
    try:
        # Contrast-enhanced copy is used for OCR only; fall back to the original on failure.
        ocr_source = image_rgb
        try:
            ocr_source = _enhance_contrast(image_rgb)
            logger.info("Image preprocessing completed")
        except Exception as e:
            logger.warning("Failed to preprocess image: %s", str(e))

        height, width = image_rgb.shape[:2]
        # post_process_generation parses decoded text that still contains the
        # special location tokens, so decode with skip_special_tokens=False.
        generated_text = _generate(
            _DETECTION_TASK, image_rgb, num_beams=3, skip_special_tokens=False
        )
        predictions = processor.post_process_generation(
            generated_text, task=_DETECTION_TASK, image_size=(width, height)
        )
        logger.info("Detected objects: %s", predictions)

        found = predictions.get(_DETECTION_TASK, {})
        boxes = []
        for bbox, label in zip(found.get("bboxes", []), found.get("labels", [])):
            if not _is_candle_label(label):
                continue
            box = _clip_box(bbox, width, height)
            xmin, ymin, xmax, ymax = box
            if xmax <= xmin or ymax <= ymin:
                logger.warning("Empty ROI for box: (%d, %d, %d, %d)", xmin, ymin, xmax, ymax)
                continue
            boxes.append((label, box))

        # Sort by x-centre, right to left (= newest candles first) and keep the
        # newest boxes *before* running OCR, so at most _MAX_CANDLES OCR calls happen.
        boxes.sort(key=lambda item: (item[1][0] + item[1][2]) / 2, reverse=True)
        detections = [
            _describe_candle(label, box, image_rgb, ocr_source)
            for label, box in boxes[:_MAX_CANDLES]
        ]
        logger.info("Sorted detections: %d", len(detections))
        if not detections:
            logger.warning("No candlesticks detected. Ensure clear image with visible candles.")
            return {"prompt": prompt, "description": "No candlesticks detected. Try a clearer screenshot with visible candles and prices."}
        return {"prompt": prompt, "detections": detections}
    except Exception as e:
        logger.exception("Failed to analyze candles: %s", e)
        return {"prompt": prompt, "description": "Error analyzing candles. Try a clearer screenshot with visible candles and prices."}


def analyze_image(image, prompt):
    """Answer a supported prompt about an uploaded image using Florence-2.

    Two prompts are recognised (case-insensitive substring match):

    * "what do you see" / "was siehst du" -> a detailed caption of the image.
    * "last 8 candles" / "letzte 8 kerzen" -> object detection filtered to
      candle/bar/chart labels; for the right-most (newest) 8 boxes the mean
      box colour and OCR text from the surrounding area are reported.

    Args:
        image: The uploaded image (a PIL image from Gradio, or a numpy array).
        prompt: Free-text prompt; ``None`` is treated as an empty prompt.

    Returns:
        A dict with the ``prompt`` plus either a ``description`` string (the
        result or a user-facing error message) or a ``detections`` list of
        dicts with keys ``pattern``, ``color``, ``prices`` and ``x_center``.
    """
    prompt = prompt or ""
    logger.info("Starting image analysis with prompt: %s", prompt)
    try:
        image_rgb = _to_rgb_array(image)
    except Exception as e:
        logger.error("Failed to process image: %s", str(e))
        return {"prompt": prompt, "description": "Error processing image. Ensure a valid image is uploaded."}
    logger.info("Image shape: %s", image_rgb.shape)

    lowered = prompt.lower()
    if "what do you see" in lowered or "was siehst du" in lowered:
        return _describe_image(image_rgb, prompt)
    if "last 8 candles" in lowered or "letzte 8 kerzen" in lowered:
        return _analyze_candles(image_rgb, prompt)
    return {"prompt": prompt, "description": "Unsupported prompt. Use 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'."}
# Build the Gradio interface
iface = gr.Interface(
    fn=analyze_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?' or 'List last 8 candles with their colors'",
        ),
    ],
    outputs="json",
    title="Image Analysis with Florence-2-base",
    description="Upload an image to analyze candlesticks or get a general description.",
)

# Start the web server only when executed as a script, so importing this
# module does not launch the app.
if __name__ == "__main__":
    iface.launch()