doinglean committed
Commit e249cba · verified · 1 Parent(s): e265812

Update app.py

Files changed (1):
  1. app.py +73 -54
app.py CHANGED
@@ -1,82 +1,101 @@
 import gradio as gr
-from transformers import AutoModel, AutoImageProcessor
+from huggingface_hub import hf_hub_download
+from ultralytics import YOLO
 import cv2
 import numpy as np
-import torch
-import os
-from huggingface_hub import login
+import easyocr
+import logging
 
-# Authenticate with the Hugging Face API token
-HF_TOKEN = os.getenv("HF_TOKEN")
-if HF_TOKEN:
-    login(HF_TOKEN)
-else:
-    raise ValueError("HF_TOKEN environment variable not set. Please add it in Space settings.")
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-# Load the model and the image processor
-model = AutoModel.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
-image_processor = AutoImageProcessor.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
+# Load the model
+model_path = hf_hub_download(repo_id="foduucom/stockmarket-pattern-detection-yolov8", filename="model.pt")
+model = YOLO(model_path)
+
+# OCR reader for prices
+reader = easyocr.Reader(['en'], gpu=False)
 
 def analyze_image(image, prompt):
+    logger.info("Starting image analysis with prompt: %s", prompt)
+
     # Convert the PIL image to OpenCV format
     image_np = np.array(image)
     image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+
+    # Image preprocessing: increase contrast with CLAHE; keep an unenhanced
+    # copy so that candle colors can still be sampled afterwards
+    image_bgr = image_cv.copy()
+    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+    gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
+    enhanced = clahe.apply(gray)
+    image_cv = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
+    logger.info("Image preprocessed: shape=%s", image_np.shape)
 
-    # Extract features with DINOv3
-    inputs = image_processor(images=image_np, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model(**inputs)
+    # Run object detection
+    results = model.predict(source=image_np, conf=0.3, iou=0.5, save=False)
+    logger.info("YOLO predictions: %d boxes detected", len(results[0].boxes))
 
-    # Improved image analysis with OpenCV
-    gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
-    # Adaptive threshold for better contour detection
-    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
-    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-    # Analyze the image based on the prompt
-    description = []
-    if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
-        if len(contours) == 0:
-            description.append("Das Bild enthält keine klar erkennbaren Objekte.")
-        else:
-            for idx, contour in enumerate(contours[:20]):  # limit to 20 objects
-                # Filter out very small and very large contours
-                if cv2.contourArea(contour) < 200 or cv2.contourArea(contour) > (image_np.shape[0] * image_np.shape[1] * 0.5):
-                    continue
-                x, y, w, h = cv2.boundingRect(contour)
-                # Extract the region's color
-                roi = image_cv[y:y+h, x:x+w]
-                if roi.size == 0:
-                    continue
-                mean_color = np.mean(roi, axis=(0, 1)).astype(int)
-                color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
-                # Size category
-                size = "small" if w * h < 1000 else "medium" if w * h < 5000 else "large"
-                description.append({
-                    "object": f"Object_{idx}",
-                    "color": color_rgb,
-                    "position": f"x={x}, y={y}, width={w}, height={h}",
-                    "size": size
-                })
+    # Extract the candlesticks
+    detections = []
+    for result in results:
+        for box in result.boxes:
+            label = result.names[int(box.cls)]
+            confidence = float(box.conf)
+            xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
+            logger.info("Detected: %s, confidence=%.2f, box=(%.0f, %.0f, %.0f, %.0f)",
+                        label, confidence, xmin, ymin, xmax, ymax)
+
+            # Extract the color (focus on the candle body); sample from the
+            # unenhanced copy, since image_cv is grayscale after CLAHE
+            candle_roi = image_bgr[int(ymin):int(ymax), int(xmin):int(xmax)]
+            if candle_roi.size == 0:
+                logger.warning("Empty ROI for box: (%.0f, %.0f, %.0f, %.0f)", xmin, ymin, xmax, ymax)
+                continue
+            mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
+            color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
+
+            # OCR for prices (widen the ROI to reach the axis labels)
+            price_roi = image_cv[max(0, int(ymin)-50):min(image_np.shape[0], int(ymax)+50),
+                                 max(0, int(xmin)-50):min(image_np.shape[1], int(xmax)+50)]
+            price_text = reader.readtext(price_roi, detail=0, allowlist='0123456789.')
+            prices = ' '.join(price_text) if price_text else "No price detected"
+            logger.info("OCR prices: %s", prices)
+
+            detections.append({
+                "pattern": label,
+                "confidence": confidence,
+                "color": color_rgb,
+                "prices": prices,
+                "x_center": (xmin + xmax) / 2
+            })
 
-    # Simple analysis of the DINOv3 features (e.g., number of feature regions)
-    feature_info = str(outputs.last_hidden_state.shape) if hasattr(outputs, 'last_hidden_state') else "No features extracted."
+    # Sort by x position (right to left = most recent candles)
+    detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)
+    logger.info("Sorted detections: %d", len(detections))
+
+    # Limit to the last 8 candles
+    if "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
+        detections = detections[:8]
+
+    # Debugging: if nothing was found, return a hint
+    if not detections:
+        logger.warning("No detections found. Check image quality or model configuration.")
+        return {"prompt": prompt, "description": "No candlesticks detected. Ensure clear image and visible candles."}
 
     return {
         "prompt": prompt,
-        "description": description if description else "No objects detected.",
-        "features_shape": feature_info
+        "detections": detections
     }
 
 # Create the Gradio interface
 iface = gr.Interface(
     fn=analyze_image,
     inputs=[
-        gr.Image(type="pil", label="Upload an Image"),
-        gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?'")
+        gr.Image(type="pil", label="Upload TradingView Screenshot"),
+        gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'List last 8 candles with their colors'")
     ],
     outputs="json",
-    title="General Image Analysis with DINOv3",
-    description="Upload an image and provide a prompt to get a description of what the model sees."
+    title="Stock Chart Analysis with YOLOv8",
+    description="Upload a TradingView screenshot to detect the last 8 candlesticks, their colors, and prices."
 )
 
 iface.launch()
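
For a quick check outside the Gradio UI, a minimal smoke test of the updated pipeline is sketched below; it is not part of this commit, and chart.png is a hypothetical local screenshot path.

# Hypothetical smoke test: run in a session where app.py's model, reader,
# and analyze_image are already defined (i.e., before iface.launch()).
from PIL import Image

img = Image.open("chart.png").convert("RGB")  # assumed local TradingView screenshot
result = analyze_image(img, "List last 8 candles with their colors")
print(result)

Note that the first run is slow: hf_hub_download fetches the YOLO weights and easyocr.Reader downloads its recognition models, but both are cached for subsequent runs.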