import gradio as gr
import torch
import sys
import os
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
# Import spaces for GPU support on Hugging Face Spaces
try:
    import spaces
    HF_SPACES = True
except ImportError:
    HF_SPACES = False
    # Create a dummy decorator if not on Spaces
    def spaces_gpu_decorator(func):
        return func
    spaces = type('spaces', (), {'GPU': spaces_gpu_decorator})()
# Add parent directory to path to import our modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from training.model import CrossAttentionModel
from transformers import BertTokenizer, ViTImageProcessor
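# CrossAttentionModel (defined in training/model.py) fuses BERT text features with ViT image
# features and returns classification logits plus a next-close price forecast.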
class CandleFusionDemo:
    def __init__(self, model_path=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load model from Hugging Face
        self.model = CrossAttentionModel()
        try:
            # Download model from Hugging Face Hub
            print("Downloading model from Hugging Face...")
            model_file = hf_hub_download(
                repo_id="tuankg1028/candlefusion",
                filename="pytorch_model.bin",
                cache_dir="./model_cache"
            )
            # Load the downloaded model
            self.model.load_state_dict(torch.load(model_file, map_location=self.device))
            print("Model loaded from Hugging Face: tuankg1028/candlefusion")
        except Exception as e:
            print(f"Error loading model from Hugging Face: {str(e)}")
            print("Warning: using untrained model instead.")

        self.model.to(self.device)
        self.model.eval()

        # Initialize processors
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Class labels
        self.class_labels = ["Bearish", "Bullish"]
    def preprocess_inputs(self, image):
        """Preprocess image input for the model"""
        # Process image
        if image is None:
            raise ValueError("Please upload a candlestick chart image")
        image = Image.fromarray(image).convert("RGB")
        image_inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].to(self.device)

        # Process text with default value
        text = "Market analysis"  # Default text
        text_inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=64
        )
        input_ids = text_inputs["input_ids"].to(self.device)
        attention_mask = text_inputs["attention_mask"].to(self.device)

        return pixel_values, input_ids, attention_mask
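    # On Hugging Face Spaces, spaces.GPU requests a GPU for the duration of this call;
    # when the spaces package is unavailable it resolves to the no-op decorator defined above.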
    @spaces.GPU
    def predict(self, image):
        """Make prediction using the model"""
        try:
            # Preprocess inputs
            pixel_values, input_ids, attention_mask = self.preprocess_inputs(image)

            # Model prediction
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values
                )
                logits = outputs["logits"]
                forecast = outputs["forecast"]

            # Get classification results
            probabilities = torch.softmax(logits, dim=1)
            predicted_class = torch.argmax(logits, dim=1).item()
            confidence = probabilities[0][predicted_class].item()

            # Get price forecast
            predicted_price = forecast.squeeze().item()

            # Format results
            classification_result = f"**Prediction:** {self.class_labels[predicted_class]}\n"
            classification_result += f"**Confidence:** {confidence:.2%}\n\n"
            classification_result += "**Class Probabilities:**\n"
            for i, (label, prob) in enumerate(zip(self.class_labels, probabilities[0])):
                classification_result += f"- {label}: {prob:.2%}\n"

            forecast_result = f"**Predicted Next Close Price:** ${predicted_price:.2f}"

            return classification_result, forecast_result
        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            return error_msg, error_msg
def create_demo():
"""Create and launch the Gradio demo"""
demo_instance = CandleFusionDemo()
# Create Gradio interface
with gr.Blocks(title="CandleFusion - Candlestick Chart Analysis", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ๐ฏ๏ธ CandleFusion Demo
Upload a candlestick chart image to get:
- **Market Direction Prediction** (Bullish/Bearish)
- **Next Close Price Forecast**
This model analyzes candlestick charts using BERT + ViT architecture.
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### ๐ Input")
image_input = gr.Image(
label="Candlestick Chart",
type="numpy",
height=300
)
predict_btn = gr.Button("๐ฎ Analyze Chart", variant="primary")
gr.Markdown("""
### ๐ก Tips:
- Upload clear candlestick chart images
- Charts should show recent price action
""")
with gr.Column(scale=1):
gr.Markdown("### ๐ Results")
classification_output = gr.Markdown(
value="Upload an image and click 'Analyze Chart' to see prediction"
)
forecast_output = gr.Markdown(
value=""
)
# Example section
gr.Markdown("### ๐ Example")
gr.Examples(
examples=[
["examples/example_chart.png"],
["examples/example_chart2.png"]
],
inputs=[image_input],
label="Try these examples:"
)
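        # The example paths above are resolved relative to the app's working directory,
        # so the corresponding images are expected to live under examples/.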
        # Connect the prediction function
        predict_btn.click(
            fn=demo_instance.predict,
            inputs=[image_input],
            outputs=[classification_output, forecast_output]
        )

        gr.Markdown("""
        ---
        **Note:** This is a demo model. For production trading decisions, always consult with financial professionals and use additional analysis tools.
        """)

    return demo
def main():
"""Main function to launch the demo"""
try:
demo = create_demo()
# Launch with server_name for compatibility on HF Spaces
demo.launch(server_name="0.0.0.0")
except Exception as e:
print(f"Failed to launch Gradio demo: {e}")
# Fallback launch with minimal configuration
demo = create_demo()
demo.launch(server_name="0.0.0.0")
if __name__ == "__main__":
    main()