tuankg1028 commited on
Commit
fbcb44a
ยท
verified ยท
1 Parent(s): aeae322

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. gradio_demo.py +12 -21
  2. training/dataset.py +7 -0
  3. training/train.py +4 -0
gradio_demo.py CHANGED
@@ -57,8 +57,8 @@ class CandleFusionDemo:
57
  # Class labels
58
  self.class_labels = ["Bearish", "Bullish"]
59
 
60
- def preprocess_inputs(self, image, text):
61
- """Preprocess image and text inputs for the model"""
62
  # Process image
63
  if image is None:
64
  raise ValueError("Please upload a candlestick chart image")
@@ -67,9 +67,8 @@ class CandleFusionDemo:
67
  image_inputs = self.processor(images=image, return_tensors="pt")
68
  pixel_values = image_inputs["pixel_values"].to(self.device)
69
 
70
- # Process text
71
- if not text.strip():
72
- text = "Market analysis" # Default text if empty
73
 
74
  text_inputs = self.tokenizer(
75
  text,
@@ -84,11 +83,11 @@ class CandleFusionDemo:
84
  return pixel_values, input_ids, attention_mask
85
 
86
  @spaces.GPU
87
- def predict(self, image, text):
88
  """Make prediction using the model"""
89
  try:
90
  # Preprocess inputs
91
- pixel_values, input_ids, attention_mask = self.preprocess_inputs(image, text)
92
 
93
  # Model prediction
94
  with torch.no_grad():
@@ -133,11 +132,11 @@ def create_demo():
133
  gr.Markdown("""
134
  # ๐Ÿ•ฏ๏ธ CandleFusion Demo
135
 
136
- Upload a candlestick chart image and provide market context to get:
137
  - **Market Direction Prediction** (Bullish/Bearish)
138
  - **Next Close Price Forecast**
139
 
140
- This model combines visual analysis of candlestick charts with textual market context using BERT + ViT architecture.
141
  """)
142
 
143
  with gr.Row():
@@ -150,19 +149,11 @@ def create_demo():
150
  height=300
151
  )
152
 
153
- text_input = gr.Textbox(
154
- label="Market Context",
155
- placeholder="Enter market analysis, news, or context (e.g., 'Strong volume with positive earnings report')",
156
- lines=3,
157
- value="Technical analysis of price action"
158
- )
159
-
160
  predict_btn = gr.Button("๐Ÿ”ฎ Analyze Chart", variant="primary")
161
 
162
  gr.Markdown("""
163
  ### ๐Ÿ’ก Tips:
164
  - Upload clear candlestick chart images
165
- - Provide relevant market context
166
  - Charts should show recent price action
167
  """)
168
 
@@ -181,17 +172,17 @@ def create_demo():
181
  gr.Markdown("### ๐Ÿ“š Example")
182
  gr.Examples(
183
  examples=[
184
- ["examples/example_chart.png", "Strong bullish momentum with high volume"],
185
- ["examples/example_chart2.png", "Bearish reversal pattern forming"]
186
  ],
187
- inputs=[image_input, text_input],
188
  label="Try these examples:"
189
  )
190
 
191
  # Connect the prediction function
192
  predict_btn.click(
193
  fn=demo_instance.predict,
194
- inputs=[image_input, text_input],
195
  outputs=[classification_output, forecast_output]
196
  )
197
 
 
57
  # Class labels
58
  self.class_labels = ["Bearish", "Bullish"]
59
 
60
+ def preprocess_inputs(self, image):
61
+ """Preprocess image input for the model"""
62
  # Process image
63
  if image is None:
64
  raise ValueError("Please upload a candlestick chart image")
 
67
  image_inputs = self.processor(images=image, return_tensors="pt")
68
  pixel_values = image_inputs["pixel_values"].to(self.device)
69
 
70
+ # Process text with default value
71
+ text = "Market analysis" # Default text
 
72
 
73
  text_inputs = self.tokenizer(
74
  text,
 
83
  return pixel_values, input_ids, attention_mask
84
 
85
  @spaces.GPU
86
+ def predict(self, image):
87
  """Make prediction using the model"""
88
  try:
89
  # Preprocess inputs
90
+ pixel_values, input_ids, attention_mask = self.preprocess_inputs(image)
91
 
92
  # Model prediction
93
  with torch.no_grad():
 
132
  gr.Markdown("""
133
  # ๐Ÿ•ฏ๏ธ CandleFusion Demo
134
 
135
+ Upload a candlestick chart image to get:
136
  - **Market Direction Prediction** (Bullish/Bearish)
137
  - **Next Close Price Forecast**
138
 
139
+ This model analyzes candlestick charts using BERT + ViT architecture.
140
  """)
141
 
142
  with gr.Row():
 
149
  height=300
150
  )
151
 
 
 
 
 
 
 
 
152
  predict_btn = gr.Button("๐Ÿ”ฎ Analyze Chart", variant="primary")
153
 
154
  gr.Markdown("""
155
  ### ๐Ÿ’ก Tips:
156
  - Upload clear candlestick chart images
 
157
  - Charts should show recent price action
158
  """)
159
 
 
172
  gr.Markdown("### ๐Ÿ“š Example")
173
  gr.Examples(
174
  examples=[
175
+ ["examples/example_chart.png"],
176
+ ["examples/example_chart2.png"]
177
  ],
178
+ inputs=[image_input],
179
  label="Try these examples:"
180
  )
181
 
182
  # Connect the prediction function
183
  predict_btn.click(
184
  fn=demo_instance.predict,
185
+ inputs=[image_input],
186
  outputs=[classification_output, forecast_output]
187
  )
188
 
training/dataset.py CHANGED
@@ -16,6 +16,13 @@ class CandlestickDataset(Dataset):
16
  image_size (int): Size to resize chart images to (default 224)
17
  """
18
  self.data = pd.read_csv(csv_path)
 
 
 
 
 
 
 
19
  self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
20
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
21
  self.image_size = image_size
 
16
  image_size (int): Size to resize chart images to (default 224)
17
  """
18
  self.data = pd.read_csv(csv_path)
19
+
20
+ # Validate required columns
21
+ required_columns = ["image_path", "text", "label", "next_close"]
22
+ for col in required_columns:
23
+ if col not in self.data.columns:
24
+ raise ValueError(f"Missing required column: {col} in {csv_path}")
25
+
26
  self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
27
  self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
28
  self.image_size = image_size
training/train.py CHANGED
@@ -102,6 +102,10 @@ tags:
102
 
103
  A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
104
 
 
 
 
 
105
  ## Architecture Overview
106
 
107
  ### Core Components
 
102
 
103
  A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
104
 
105
+ ## Links
106
+ - ๐Ÿ”— **GitHub Repository**: https://github.com/tuankg1028/CandleFusion
107
+ - ๐Ÿš€ **Demo on Hugging Face Spaces**: https://huggingface.co/spaces/tuankg1028/candlefusion
108
+
109
  ## Architecture Overview
110
 
111
  ### Core Components