doinglean commited on
Commit
dee9ee3
·
verified ·
1 Parent(s): 9b9b214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -36
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoModelForObjectDetection, AutoImageProcessor
3
  import cv2
4
  import numpy as np
5
  import torch
@@ -14,7 +14,7 @@ else:
14
  raise ValueError("HF_TOKEN environment variable not set. Please add it in Space settings.")
15
 
16
  # Lade das Modell und den Image Processor
17
- model = AutoModelForObjectDetection.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
18
  image_processor = AutoImageProcessor.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
19
 
20
  def analyze_image(image, prompt):
@@ -22,53 +22,54 @@ def analyze_image(image, prompt):
22
  image_np = np.array(image)
23
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
24
 
25
- # Vorbereitung für DINOv3
26
  inputs = image_processor(images=image_np, return_tensors="pt")
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
 
30
- # Extrahiere Bounding-Boxes und Labels
31
- detections = []
32
- for score, label, box in zip(outputs.logits.softmax(-1).max(-1)[0], outputs.logits.softmax(-1).max(-1)[1], outputs.pred_boxes):
33
- if score > 0.5: # Confidence-Schwellenwert
34
- # Konvertiere Box-Koordinaten
35
- box = box * torch.tensor([image_np.shape[1], image_np.shape[0], image_np.shape[1], image_np.shape[0]])
36
- xmin, ymin, xmax, ymax = box.int().tolist()
37
 
38
- # Schneide die Kerze für Farbanalyse
39
- candle_roi = image_cv[ymin:ymax, xmin:xmax]
40
- if candle_roi.size == 0: # Vermeide leere ROIs
41
- continue
42
- mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
43
- color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- detections.append({
46
- "pattern": f"Candlestick_{label.item()}", # Generische Labels
47
- "confidence": score.item(),
48
- "color": color_rgb,
49
- "x_center": (xmin + xmax) / 2,
50
- "prompt_used": prompt
51
- })
52
-
53
- # Sortiere nach x-Position (rechts nach links = neueste Kerzen)
54
- detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)
55
-
56
- # Begrenze auf die letzten 10 Kerzen, wenn im Prompt gefordert
57
- if "last 10 candles" in prompt.lower():
58
- detections = detections[:10]
59
-
60
- return detections
61
 
62
  # Erstelle Gradio-Schnittstelle
63
  iface = gr.Interface(
64
  fn=analyze_image,
65
  inputs=[
66
- gr.Image(type="pil", label="Upload TradingView Screenshot"),
67
- gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'List last 10 candles with their colors'")
68
  ],
69
  outputs="json",
70
- title="Candlestick Pattern Detection with DINOv3",
71
- description="Upload a TradingView screenshot and provide a prompt to detect candlestick patterns and colors."
72
  )
73
 
74
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoImageProcessor
3
  import cv2
4
  import numpy as np
5
  import torch
 
14
  raise ValueError("HF_TOKEN environment variable not set. Please add it in Space settings.")
15
 
16
  # Lade das Modell und den Image Processor
17
+ model = AutoModel.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
18
  image_processor = AutoImageProcessor.from_pretrained("facebook/dinov3-convnext-small-pretrain-lvd1689m")
19
 
20
  def analyze_image(image, prompt):
 
22
  image_np = np.array(image)
23
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
24
 
25
+ # Extrahiere Features mit DINOv3
26
  inputs = image_processor(images=image_np, return_tensors="pt")
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
 
30
+ # Einfache Bildanalyse mit OpenCV (Konturen für Objekte)
31
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
32
+ _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
33
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
34
 
35
+ # Analysiere das Bild basierend auf dem Prompt
36
+ description = []
37
+ if "what do you see" in prompt.lower() or "was siehst du" in prompt.lower():
38
+ if len(contours) == 0:
39
+ description.append("Das Bild enthält keine klar erkennbaren Objekte.")
40
+ else:
41
+ for idx, contour in enumerate(contours[:10]): # Begrenze auf 10 Objekte
42
+ if cv2.contourArea(contour) < 100: # Ignoriere kleine Konturen
43
+ continue
44
+ x, y, w, h = cv2.boundingRect(contour)
45
+ # Extrahiere Farbe der Region
46
+ roi = image_cv[y:y+h, x:x+w]
47
+ if roi.size == 0:
48
+ continue
49
+ mean_color = np.mean(roi, axis=(0, 1)).astype(int)
50
+ color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
51
+ description.append({
52
+ "object": f"Object_{idx}",
53
+ "color": color_rgb,
54
+ "position": f"x={x}, y={y}, width={w}, height={h}"
55
+ })
56
 
57
+ return {
58
+ "prompt": prompt,
59
+ "description": description if description else "No objects detected.",
60
+ "features_shape": str(outputs.last_hidden_state.shape) if hasattr(outputs, 'last_hidden_state') else "No features extracted."
61
+ }
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # Erstelle Gradio-Schnittstelle
64
  iface = gr.Interface(
65
  fn=analyze_image,
66
  inputs=[
67
+ gr.Image(type="pil", label="Upload an Image"),
68
+ gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Was siehst du auf dem Bild?'")
69
  ],
70
  outputs="json",
71
+ title="General Image Analysis with DINOv3",
72
+ description="Upload an image and provide a prompt to get a description of what the model sees."
73
  )
74
 
75
  iface.launch()