Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- gradio_demo.py +12 -21
- training/dataset.py +7 -0
- training/train.py +4 -0
gradio_demo.py
CHANGED
|
@@ -57,8 +57,8 @@ class CandleFusionDemo:
|
|
| 57 |
# Class labels
|
| 58 |
self.class_labels = ["Bearish", "Bullish"]
|
| 59 |
|
| 60 |
-
def preprocess_inputs(self, image
|
| 61 |
-
"""Preprocess image
|
| 62 |
# Process image
|
| 63 |
if image is None:
|
| 64 |
raise ValueError("Please upload a candlestick chart image")
|
|
@@ -67,9 +67,8 @@ class CandleFusionDemo:
|
|
| 67 |
image_inputs = self.processor(images=image, return_tensors="pt")
|
| 68 |
pixel_values = image_inputs["pixel_values"].to(self.device)
|
| 69 |
|
| 70 |
-
# Process text
|
| 71 |
-
|
| 72 |
-
text = "Market analysis" # Default text if empty
|
| 73 |
|
| 74 |
text_inputs = self.tokenizer(
|
| 75 |
text,
|
|
@@ -84,11 +83,11 @@ class CandleFusionDemo:
|
|
| 84 |
return pixel_values, input_ids, attention_mask
|
| 85 |
|
| 86 |
@spaces.GPU
|
| 87 |
-
def predict(self, image
|
| 88 |
"""Make prediction using the model"""
|
| 89 |
try:
|
| 90 |
# Preprocess inputs
|
| 91 |
-
pixel_values, input_ids, attention_mask = self.preprocess_inputs(image
|
| 92 |
|
| 93 |
# Model prediction
|
| 94 |
with torch.no_grad():
|
|
@@ -133,11 +132,11 @@ def create_demo():
|
|
| 133 |
gr.Markdown("""
|
| 134 |
# ๐ฏ๏ธ CandleFusion Demo
|
| 135 |
|
| 136 |
-
Upload a candlestick chart image
|
| 137 |
- **Market Direction Prediction** (Bullish/Bearish)
|
| 138 |
- **Next Close Price Forecast**
|
| 139 |
|
| 140 |
-
This model
|
| 141 |
""")
|
| 142 |
|
| 143 |
with gr.Row():
|
|
@@ -150,19 +149,11 @@ def create_demo():
|
|
| 150 |
height=300
|
| 151 |
)
|
| 152 |
|
| 153 |
-
text_input = gr.Textbox(
|
| 154 |
-
label="Market Context",
|
| 155 |
-
placeholder="Enter market analysis, news, or context (e.g., 'Strong volume with positive earnings report')",
|
| 156 |
-
lines=3,
|
| 157 |
-
value="Technical analysis of price action"
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
predict_btn = gr.Button("๐ฎ Analyze Chart", variant="primary")
|
| 161 |
|
| 162 |
gr.Markdown("""
|
| 163 |
### ๐ก Tips:
|
| 164 |
- Upload clear candlestick chart images
|
| 165 |
-
- Provide relevant market context
|
| 166 |
- Charts should show recent price action
|
| 167 |
""")
|
| 168 |
|
|
@@ -181,17 +172,17 @@ def create_demo():
|
|
| 181 |
gr.Markdown("### ๐ Example")
|
| 182 |
gr.Examples(
|
| 183 |
examples=[
|
| 184 |
-
["examples/example_chart.png"
|
| 185 |
-
["examples/example_chart2.png"
|
| 186 |
],
|
| 187 |
-
inputs=[image_input
|
| 188 |
label="Try these examples:"
|
| 189 |
)
|
| 190 |
|
| 191 |
# Connect the prediction function
|
| 192 |
predict_btn.click(
|
| 193 |
fn=demo_instance.predict,
|
| 194 |
-
inputs=[image_input
|
| 195 |
outputs=[classification_output, forecast_output]
|
| 196 |
)
|
| 197 |
|
|
|
|
| 57 |
# Class labels
|
| 58 |
self.class_labels = ["Bearish", "Bullish"]
|
| 59 |
|
| 60 |
+
def preprocess_inputs(self, image):
|
| 61 |
+
"""Preprocess image input for the model"""
|
| 62 |
# Process image
|
| 63 |
if image is None:
|
| 64 |
raise ValueError("Please upload a candlestick chart image")
|
|
|
|
| 67 |
image_inputs = self.processor(images=image, return_tensors="pt")
|
| 68 |
pixel_values = image_inputs["pixel_values"].to(self.device)
|
| 69 |
|
| 70 |
+
# Process text with default value
|
| 71 |
+
text = "Market analysis" # Default text
|
|
|
|
| 72 |
|
| 73 |
text_inputs = self.tokenizer(
|
| 74 |
text,
|
|
|
|
| 83 |
return pixel_values, input_ids, attention_mask
|
| 84 |
|
| 85 |
@spaces.GPU
|
| 86 |
+
def predict(self, image):
|
| 87 |
"""Make prediction using the model"""
|
| 88 |
try:
|
| 89 |
# Preprocess inputs
|
| 90 |
+
pixel_values, input_ids, attention_mask = self.preprocess_inputs(image)
|
| 91 |
|
| 92 |
# Model prediction
|
| 93 |
with torch.no_grad():
|
|
|
|
| 132 |
gr.Markdown("""
|
| 133 |
# ๐ฏ๏ธ CandleFusion Demo
|
| 134 |
|
| 135 |
+
Upload a candlestick chart image to get:
|
| 136 |
- **Market Direction Prediction** (Bullish/Bearish)
|
| 137 |
- **Next Close Price Forecast**
|
| 138 |
|
| 139 |
+
This model analyzes candlestick charts using BERT + ViT architecture.
|
| 140 |
""")
|
| 141 |
|
| 142 |
with gr.Row():
|
|
|
|
| 149 |
height=300
|
| 150 |
)
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
predict_btn = gr.Button("๐ฎ Analyze Chart", variant="primary")
|
| 153 |
|
| 154 |
gr.Markdown("""
|
| 155 |
### ๐ก Tips:
|
| 156 |
- Upload clear candlestick chart images
|
|
|
|
| 157 |
- Charts should show recent price action
|
| 158 |
""")
|
| 159 |
|
|
|
|
| 172 |
gr.Markdown("### ๐ Example")
|
| 173 |
gr.Examples(
|
| 174 |
examples=[
|
| 175 |
+
["examples/example_chart.png"],
|
| 176 |
+
["examples/example_chart2.png"]
|
| 177 |
],
|
| 178 |
+
inputs=[image_input],
|
| 179 |
label="Try these examples:"
|
| 180 |
)
|
| 181 |
|
| 182 |
# Connect the prediction function
|
| 183 |
predict_btn.click(
|
| 184 |
fn=demo_instance.predict,
|
| 185 |
+
inputs=[image_input],
|
| 186 |
outputs=[classification_output, forecast_output]
|
| 187 |
)
|
| 188 |
|
training/dataset.py
CHANGED
|
@@ -16,6 +16,13 @@ class CandlestickDataset(Dataset):
|
|
| 16 |
image_size (int): Size to resize chart images to (default 224)
|
| 17 |
"""
|
| 18 |
self.data = pd.read_csv(csv_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 20 |
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
| 21 |
self.image_size = image_size
|
|
|
|
| 16 |
image_size (int): Size to resize chart images to (default 224)
|
| 17 |
"""
|
| 18 |
self.data = pd.read_csv(csv_path)
|
| 19 |
+
|
| 20 |
+
# Validate required columns
|
| 21 |
+
required_columns = ["image_path", "text", "label", "next_close"]
|
| 22 |
+
for col in required_columns:
|
| 23 |
+
if col not in self.data.columns:
|
| 24 |
+
raise ValueError(f"Missing required column: {col} in {csv_path}")
|
| 25 |
+
|
| 26 |
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 27 |
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
| 28 |
self.image_size = image_size
|
training/train.py
CHANGED
|
@@ -102,6 +102,10 @@ tags:
|
|
| 102 |
|
| 103 |
A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
## Architecture Overview
|
| 106 |
|
| 107 |
### Core Components
|
|
|
|
| 102 |
|
| 103 |
A multimodal financial analysis model that combines textual market sentiment with visual candlestick patterns for enhanced trading signal prediction and price forecasting.
|
| 104 |
|
| 105 |
+
## Links
|
| 106 |
+
- ๐ **GitHub Repository**: https://github.com/tuankg1028/CandleFusion
|
| 107 |
+
- ๐ **Demo on Hugging Face Spaces**: https://huggingface.co/spaces/tuankg1028/candlefusion
|
| 108 |
+
|
| 109 |
## Architecture Overview
|
| 110 |
|
| 111 |
### Core Components
|