File size: 7,488 Bytes
8f1d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcb44a
 
8f1d3f9
 
 
 
 
 
 
 
fbcb44a
 
8f1d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcb44a
8f1d3f9
 
 
fbcb44a
8f1d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcb44a
8f1d3f9
 
 
fbcb44a
8f1d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbcb44a
 
8f1d3f9
fbcb44a
8f1d3f9
 
 
 
 
 
fbcb44a
8f1d3f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import gradio as gr
import torch
import sys
import os
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download

# Import spaces for GPU support on Hugging Face Spaces. Off Spaces, install a
# no-op stand-in so that ``@spaces.GPU`` still works as a decorator.
try:
    import spaces
    HF_SPACES = True
except ImportError:
    HF_SPACES = False

    def spaces_gpu_decorator(func=None, **_kwargs):
        """No-op replacement for ``spaces.GPU``.

        Supports both the bare form ``@spaces.GPU`` and the parameterised
        form ``@spaces.GPU(duration=...)``; the decorated function is
        returned unchanged.
        """
        if func is None:
            # Called with options only: hand back an identity decorator.
            return lambda f: f
        return func

    import types
    # A SimpleNamespace stores the function as a plain attribute. An instance
    # of an ad-hoc class would instead turn it into a bound method, so
    # ``@spaces.GPU`` would pass the instance as ``func`` and the real
    # function as an extra argument (TypeError).
    spaces = types.SimpleNamespace(GPU=spaces_gpu_decorator)

# Add parent directory to path to import our modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from training.model import CrossAttentionModel
from transformers import BertTokenizer, ViTImageProcessor

class CandleFusionDemo:
    """Gradio-facing inference wrapper around ``CrossAttentionModel``.

    Loads trained weights (from ``model_path`` if given, otherwise from the
    Hugging Face Hub). If loading fails, it keeps the untrained model. It
    turns a chart image into a Bearish/Bullish classification plus a
    next-close price forecast, both formatted as Markdown.
    """

    # Hugging Face Hub location of the trained checkpoint.
    HF_REPO_ID = "tuankg1028/candlefusion"
    HF_FILENAME = "pytorch_model.bin"

    def __init__(self, model_path=None):
        """Create the model, load its weights and set up the preprocessors.

        Args:
            model_path: Optional path to a local state-dict checkpoint. When
                ``None`` (the default), the checkpoint is downloaded from the
                Hugging Face Hub into ``./model_cache``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CrossAttentionModel()

        source = model_path if model_path is not None else f"Hugging Face: {self.HF_REPO_ID}"
        try:
            if model_path is not None:
                model_file = model_path
            else:
                print("📥 Downloading model from Hugging Face...")
                model_file = hf_hub_download(
                    repo_id=self.HF_REPO_ID,
                    filename=self.HF_FILENAME,
                    cache_dir="./model_cache",
                )

            # weights_only=True limits unpickling to tensors and plain
            # containers, so a tampered downloaded checkpoint cannot run code.
            state_dict = torch.load(model_file, map_location=self.device, weights_only=True)
            self.model.load_state_dict(state_dict)
            print(f"✅ Model loaded from {source}")
        except Exception as e:
            # Best-effort by design: keep the demo usable with untrained weights.
            print(f"❌ Error loading model from {source}: {e}")
            print("⚠️ Using untrained model instead.")

        self.model.to(self.device)
        self.model.eval()

        # Text and image preprocessors matching the BERT / ViT backbones.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Label i corresponds to logit column i (see argmax in ``predict``).
        self.class_labels = ["Bearish", "Bullish"]

    def preprocess_inputs(self, image):
        """Convert an uploaded chart into model-ready tensors.

        Args:
            image: The chart as a NumPy array (what ``gr.Image(type="numpy")``
                delivers) or a ``PIL.Image.Image``.

        Returns:
            ``(pixel_values, input_ids, attention_mask)``, all on ``self.device``.

        Raises:
            ValueError: If no image was supplied.
        """
        if image is None:
            raise ValueError("Please upload a candlestick chart image")

        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image))
        image = image.convert("RGB")
        image_inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].to(self.device)

        # The model also takes a text branch; the demo feeds a fixed prompt.
        text = "Market analysis"

        text_inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=64,
        )
        input_ids = text_inputs["input_ids"].to(self.device)
        attention_mask = text_inputs["attention_mask"].to(self.device)

        return pixel_values, input_ids, attention_mask

    @spaces.GPU
    def predict(self, image):
        """Run the model on one chart image and format the results.

        Args:
            image: Chart image accepted by ``preprocess_inputs``.

        Returns:
            ``(classification_markdown, forecast_markdown)``. On any error,
            both elements are the same error message, so the UI shows it in
            both panels.
        """
        try:
            pixel_values, input_ids, attention_mask = self.preprocess_inputs(image)

            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                )

            logits = outputs["logits"]
            forecast = outputs["forecast"]

            # Single-image batch: take row 0 of the class probabilities.
            probabilities = torch.softmax(logits, dim=1)[0].tolist()
            predicted_class = torch.argmax(logits, dim=1).item()
            confidence = probabilities[predicted_class]

            predicted_price = forecast.squeeze().item()

            lines = [
                f"**Prediction:** {self.class_labels[predicted_class]}",
                f"**Confidence:** {confidence:.2%}",
                "",
                "**Class Probabilities:**",
            ]
            lines += [
                f"- {label}: {prob:.2%}"
                for label, prob in zip(self.class_labels, probabilities)
            ]
            classification_result = "\n".join(lines) + "\n"

            forecast_result = f"**Predicted Next Close Price:** ${predicted_price:.2f}"

            return classification_result, forecast_result

        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            return error_msg, error_msg

def create_demo():
    """Build the CandleFusion Gradio interface.

    Instantiates ``CandleFusionDemo`` (which loads the model) and wires its
    ``predict`` method to the UI. The demo is not launched here.

    Returns:
        The configured ``gr.Blocks`` app; the caller is responsible for
        calling ``launch()``.
    """
    demo_instance = CandleFusionDemo()

    # Create Gradio interface
    with gr.Blocks(title="CandleFusion - Candlestick Chart Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🕯️ CandleFusion Demo
        
        Upload a candlestick chart image to get:
        - **Market Direction Prediction** (Bullish/Bearish)
        - **Next Close Price Forecast**
        
        This model analyzes candlestick charts using BERT + ViT architecture.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Input")

                # type="numpy" matches what CandleFusionDemo.preprocess_inputs expects.
                image_input = gr.Image(
                    label="Candlestick Chart",
                    type="numpy",
                    height=300
                )

                predict_btn = gr.Button("🔮 Analyze Chart", variant="primary")

                gr.Markdown("""
                ### 💡 Tips:
                - Upload clear candlestick chart images
                - Charts should show recent price action
                """)

            with gr.Column(scale=1):
                gr.Markdown("### 📈 Results")

                classification_output = gr.Markdown(
                    value="Upload an image and click 'Analyze Chart' to see prediction"
                )

                forecast_output = gr.Markdown(
                    value=""
                )

        # Example section; paths are resolved relative to the working directory.
        gr.Markdown("### 📚 Example")
        gr.Examples(
            examples=[
                ["examples/example_chart.png"],
                ["examples/example_chart2.png"]
            ],
            inputs=[image_input],
            label="Try these examples:"
        )

        # predict returns (classification, forecast), matching the output order.
        predict_btn.click(
            fn=demo_instance.predict,
            inputs=[image_input],
            outputs=[classification_output, forecast_output]
        )

        gr.Markdown("""
        ---
        **Note:** This is a demo model. For production trading decisions, always consult with financial professionals and use additional analysis tools.
        """)

    return demo

def main():
    """Build the demo once and serve it.

    First launches on all interfaces (``0.0.0.0``) for compatibility with
    Hugging Face Spaces. If that launch fails, it retries once with Gradio's
    default launch settings. Errors from building the demo are not retried;
    they propagate immediately.
    """
    demo = create_demo()
    try:
        # Launch with server_name for compatibility on HF Spaces
        demo.launch(server_name="0.0.0.0")
    except Exception as e:
        print(f"Failed to launch Gradio demo: {e}")
        # Fallback launch with minimal configuration (Gradio defaults) rather
        # than repeating the identical call and rebuilding the model.
        demo.launch()


if __name__ == "__main__":
    main()