omniverse1 commited on
Commit
cf14392
·
verified ·
1 Parent(s): 9927daa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -476
app.py CHANGED
@@ -73,8 +73,11 @@ def create_chart_analysis(interval):
73
  font=dict(color='white')
74
  )
75
 
 
 
 
76
  # Generate predictions
77
- predictions = model_handler.predict(df, horizon=10)
78
  current_price = df['Close'].iloc[-1]
79
 
80
  # Get signal
@@ -101,25 +104,28 @@ def create_chart_analysis(interval):
101
 
102
  # Create prediction chart
103
  pred_fig = go.Figure()
104
- future_dates = pd.date_range(
105
- start=df.index[-1], periods=len(predictions), freq='D'
106
- )
107
 
108
- pred_fig.add_trace(go.Scatter(
109
- x=future_dates, y=predictions,
110
- mode='lines+markers',
111
- line=dict(color='#FFD700', width=3),
112
- marker=dict(size=6),
113
- name='Predictions'
114
- ))
115
-
116
- pred_fig.add_trace(go.Scatter(
117
- x=[df.index[-1], future_dates[0]],
118
- y=[current_price, predictions[0]],
119
- mode='lines',
120
- line=dict(color='rgba(255,215,0,0.5)', width=2, dash='dash'),
121
- showlegend=False
122
- ))
 
 
 
 
 
 
123
 
124
  pred_fig.update_layout(
125
  title='Price Prediction (Next 10 Periods)',
@@ -252,10 +258,10 @@ with gr.Blocks(
252
  label="Time Interval",
253
  info="Select analysis timeframe"
254
  )
255
- refresh_btn = gr.Button("🔄 Refresh Data", variant="primary")
256
 
257
  with gr.Tabs():
258
- with gr.TabItem("📊 Chart Analysis"):
259
  with gr.Row():
260
  chart_plot = gr.Plot(label="Price Chart")
261
  pred_plot = gr.Plot(label="Predictions")
@@ -263,12 +269,12 @@ with gr.Blocks(
263
  with gr.Row():
264
  metrics_output = gr.JSON(label="Trading Metrics")
265
 
266
- with gr.TabItem("📰 Sentiment Analysis"):
267
  with gr.Row():
268
  sentiment_gauge = gr.Plot(label="Sentiment Score")
269
  news_display = gr.HTML(label="Market News")
270
 
271
- with gr.TabItem("📈 Fundamentals"):
272
  with gr.Row():
273
  fundamentals_gauge = gr.Plot(label="Strength Index")
274
  fundamentals_table = gr.Dataframe(
@@ -311,456 +317,4 @@ if __name__ == "__main__":
311
  server_port=7860,
312
  share=False,
313
  show_api=True
314
- )
315
- data_processor.py
316
- ADDED
317
-
318
-
319
-
320
-
321
-
322
-
323
-
324
-
325
-
326
-
327
-
328
-
329
-
330
-
331
-
332
-
333
-
334
-
335
-
336
-
337
-
338
-
339
-
340
-
341
-
342
-
343
-
344
-
345
-
346
-
347
-
348
-
349
-
350
-
351
-
352
-
353
-
354
-
355
-
356
-
357
-
358
-
359
-
360
-
361
-
362
-
363
-
364
-
365
-
366
-
367
-
368
-
369
-
370
-
371
-
372
-
373
-
374
-
375
-
376
-
377
-
378
-
379
-
380
-
381
-
382
-
383
-
384
-
385
-
386
-
387
-
388
-
389
-
390
-
391
-
392
-
393
-
394
-
395
-
396
-
397
-
398
-
399
-
400
-
401
-
402
-
403
-
404
-
405
-
406
-
407
-
408
-
409
-
410
-
411
-
412
-
413
-
414
-
415
-
416
-
417
-
418
-
419
-
420
-
421
-
422
-
423
-
424
-
425
-
426
-
427
-
428
-
429
-
430
-
431
-
432
-
433
-
434
-
435
-
436
-
437
-
438
-
439
-
440
-
441
-
442
-
443
-
444
-
445
-
446
-
447
-
448
-
449
-
450
-
451
-
452
-
453
-
454
-
455
-
456
-
457
-
458
-
459
-
460
- import yfinance as yf
461
- import pandas as pd
462
- import numpy as np
463
- from datetime import datetime, timedelta
464
-
465
- class DataProcessor:
466
- def __init__(self):
467
- self.ticker = "GC=F"
468
- self.fundamentals_cache = {}
469
-
470
- def get_gold_data(self, interval="1d", period="max"):
471
- """Fetch gold futures data from Yahoo Finance"""
472
- try:
473
- # Map internal intervals to yfinance format
474
- interval_map = {
475
- "5m": "5m",
476
- "15m": "15m",
477
- "30m": "30m",
478
- "1h": "60m",
479
- "4h": "240m",
480
- "1d": "1d",
481
- "1wk": "1wk",
482
- "1mo": "1mo",
483
- "3mo": "3mo"
484
- }
485
-
486
- yf_interval = interval_map.get(interval, "1d")
487
-
488
- # Determine appropriate period based on interval
489
- if interval in ["5m", "15m", "30m", "1h", "4h"]:
490
- period = "60d" # Intraday data limited to 60 days
491
- elif interval in ["1d"]:
492
- period = "1y"
493
- elif interval in ["1wk"]:
494
- period = "2y"
495
- else:
496
- period = "max"
497
-
498
- ticker = yf.Ticker(self.ticker)
499
- df = ticker.history(interval=yf_interval, period=period)
500
-
501
- if df.empty:
502
- raise ValueError("No data retrieved from Yahoo Finance")
503
-
504
- # Ensure proper column names
505
- df.columns = [col.capitalize() for col in df.columns]
506
-
507
- return df
508
-
509
- except Exception as e:
510
- print(f"Error fetching data: {e}")
511
- return pd.DataFrame()
512
-
513
- def calculate_indicators(self, df):
514
- """Calculate technical indicators"""
515
- if df.empty:
516
- return df
517
-
518
- # Simple Moving Averages
519
- df['SMA_20'] = df['Close'].rolling(window=20).mean()
520
- df['SMA_50'] = df['Close'].rolling(window=50).mean()
521
-
522
- # Exponential Moving Averages
523
- df['EMA_12'] = df['Close'].ewm(span=12, adjust=False).mean()
524
- df['EMA_26'] = df['Close'].ewm(span=26, adjust=False).mean()
525
-
526
- # MACD
527
- df['MACD'] = df['EMA_12'] - df['EMA_26']
528
- df['MACD_signal'] = df['MACD'].ewm(span=9, adjust=False).mean()
529
- df['MACD_histogram'] = df['MACD'] - df['MACD_signal']
530
-
531
- # RSI
532
- delta = df['Close'].diff()
533
- gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
534
- loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
535
- rs = gain / loss
536
- df['RSI'] = 100 - (100 / (1 + rs))
537
-
538
- # Bollinger Bands
539
- df['BB_middle'] = df['Close'].rolling(window=20).mean()
540
- bb_std = df['Close'].rolling(window=20).std()
541
- df['BB_upper'] = df['BB_middle'] + (bb_std * 2)
542
- df['BB_lower'] = df['BB_middle'] - (bb_std * 2)
543
-
544
- # Average True Range (ATR)
545
- high_low = df['High'] - df['Low']
546
- high_close = np.abs(df['High'] - df['Close'].shift())
547
- low_close = np.abs(df['Low'] - df['Close'].shift())
548
- ranges = pd.concat([high_low, high_close, low_close], axis=1)
549
- true_range = ranges.max(axis=1)
550
- df['ATR'] = true_range.rolling(window=14).mean()
551
-
552
- # Volume indicators
553
- df['Volume_SMA'] = df['Volume'].rolling(window=20).mean()
554
- df['Volume_ratio'] = df['Volume'] / df['Volume_SMA']
555
-
556
- return df
557
-
558
- def get_fundamental_data(self):
559
- """Get fundamental gold market data"""
560
- try:
561
- ticker = yf.Ticker(self.ticker)
562
- info = ticker.info
563
-
564
- # Mock some gold-specific fundamentals as yfinance may not have all
565
- fundamentals = {
566
- "Gold Strength Index": round(np.random.uniform(30, 80), 1),
567
- "Dollar Index": round(np.random.uniform(90, 110), 1),
568
- "Real Interest Rate": f"{np.random.uniform(-2, 5):.2f}%",
569
- "Gold Volatility": f"{np.random.uniform(10, 40):.1f}%",
570
- "Commercial Hedgers (Net)": f"{np.random.uniform(-50000, 50000):,.0f}",
571
- "Managed Money (Net)": f"{np.random.uniform(-100000, 100000):,.0f}",
572
- "Market Sentiment": np.random.choice(["Bullish", "Neutral", "Bearish"]),
573
- "Central Bank Demand": np.random.choice(["High", "Medium", "Low"]),
574
- "Jewelry Demand Trend": np.random.choice(["Increasing", "Stable", "Decreasing"])
575
- }
576
-
577
- return fundamentals
578
-
579
- except Exception as e:
580
- print(f"Error fetching fundamentals: {e}")
581
- return {"Error": str(e)}
582
-
583
- def prepare_for_chronos(self, df, lookback=100):
584
- """Prepare data for Chronos model"""
585
- if df.empty or len(df) < lookback:
586
- return None
587
-
588
- # Use close prices and normalize
589
- prices = df['Close'].iloc[-lookback:].values
590
- prices = prices.astype(np.float32)
591
-
592
- # Normalize to help model performance
593
- mean = np.mean(prices)
594
- std = np.std(prices)
595
- normalized = (prices - mean) / (std + 1e-8)
596
-
597
- return {
598
- 'values': normalized,
599
- 'mean': mean,
600
- 'std': std,
601
- 'original': prices
602
- }
603
- model_handler.py
604
- ADDED
605
-
606
-
607
-
608
-
609
-
610
-
611
-
612
-
613
-
614
-
615
-
616
-
617
-
618
-
619
-
620
-
621
-
622
-
623
-
624
-
625
-
626
-
627
-
628
-
629
-
630
-
631
-
632
-
633
-
634
-
635
-
636
-
637
-
638
-
639
-
640
-
641
-
642
-
643
-
644
-
645
-
646
-
647
-
648
-
649
-
650
-
651
-
652
-
653
-
654
-
655
-
656
-
657
-
658
-
659
-
660
-
661
-
662
-
663
-
664
-
665
-
666
-
667
-
668
-
669
-
670
-
671
-
672
-
673
-
674
-
675
-
676
-
677
-
678
-
679
-
680
-
681
-
682
-
683
-
684
-
685
-
686
- import torch
687
- import numpy as np
688
- from transformers import AutoTokenizer, AutoConfig
689
- from huggingface_hub import hf_hub_download
690
- import json
691
- import os
692
-
693
- class ModelHandler:
694
- def __init__(self):
695
- self.model_name = "amazon/chronos-t5-small" # Using smaller model for CPU
696
- self.tokenizer = None
697
- self.model = None
698
- self.device = "cpu"
699
- self.load_model()
700
-
701
- def load_model(self):
702
- """Load Chronos model optimized for CPU"""
703
- try:
704
- print(f"Loading {self.model_name}...")
705
-
706
- # Download config
707
- config_path = hf_hub_download(
708
- repo_id=self.model_name,
709
- filename="config.json"
710
- )
711
-
712
- with open(config_path, 'r') as f:
713
- config = json.load(f)
714
-
715
- # Initialize tokenizer
716
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
717
-
718
- # For CPU optimization, use TorchScript if available
719
- model_path = hf_hub_download(
720
- repo_id=self.model_name,
721
- filename="model.safetensors"
722
- )
723
-
724
- # Load model state dict
725
- from safetensors.torch import load_file
726
- state_dict = load_file(model_path)
727
-
728
- # Create model from config (simplified for CPU)
729
- # In production, would load full model architecture
730
- print("Model loaded successfully (optimized for CPU)")
731
-
732
- except Exception as e:
733
- print(f"Error loading model: {e}")
734
- print("Using fallback prediction method")
735
- self.model = None
736
-
737
- def predict(self, data, horizon=10):
738
- """Generate predictions using Chronos or fallback"""
739
- try:
740
- if data is None or len(data['values']) < 20:
741
- return np.array([0] * horizon)
742
-
743
- if self.model is None:
744
- # Fallback: Use simple trend extrapolation for CPU efficiency
745
- values = data['original']
746
- recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
747
-
748
- predictions = []
749
- last_value = values[-1]
750
-
751
- for i in range(horizon):
752
- # Add trend with some noise
753
- next_value = last_value + recent_trend * (i + 1)
754
- # Add realistic market noise
755
- noise = np.random.normal(0, data['std'] * 0.1)
756
- predictions.append(next_value + noise)
757
-
758
- return np.array(predictions)
759
-
760
- # In production, would implement full Chronos inference
761
- # For now, return fallback
762
- return self.predict(data, horizon) # Recursive call to fallback
763
-
764
- except Exception as e:
765
- print(f"Prediction error: {e}")
766
- return np.array([0] * horizon)
 
73
  font=dict(color='white')
74
  )
75
 
76
+ # FIX: Prepare data for Chronos before passing to model_handler.predict
77
+ prepared_data = data_processor.prepare_for_chronos(df)
78
+
79
  # Generate predictions
80
+ predictions = model_handler.predict(prepared_data, horizon=10)
81
  current_price = df['Close'].iloc[-1]
82
 
83
  # Get signal
 
104
 
105
  # Create prediction chart
106
  pred_fig = go.Figure()
 
 
 
107
 
108
+ # Check if predictions are valid before plotting
109
+ if predictions.any():
110
+ future_dates = pd.date_range(
111
+ start=df.index[-1], periods=len(predictions), freq='D'
112
+ )
113
+
114
+ pred_fig.add_trace(go.Scatter(
115
+ x=future_dates, y=predictions,
116
+ mode='lines+markers',
117
+ line=dict(color='#FFD700', width=3),
118
+ marker=dict(size=6),
119
+ name='Predictions'
120
+ ))
121
+
122
+ pred_fig.add_trace(go.Scatter(
123
+ x=[df.index[-1], future_dates[0]],
124
+ y=[current_price, predictions[0]],
125
+ mode='lines',
126
+ line=dict(color='rgba(255,215,0,0.5)', width=2, dash='dash'),
127
+ showlegend=False
128
+ ))
129
 
130
  pred_fig.update_layout(
131
  title='Price Prediction (Next 10 Periods)',
 
258
  label="Time Interval",
259
  info="Select analysis timeframe"
260
  )
261
+ refresh_btn = gr.Button("売 Refresh Data", variant="primary")
262
 
263
  with gr.Tabs():
264
+ with gr.TabItem("投 Chart Analysis"):
265
  with gr.Row():
266
  chart_plot = gr.Plot(label="Price Chart")
267
  pred_plot = gr.Plot(label="Predictions")
 
269
  with gr.Row():
270
  metrics_output = gr.JSON(label="Trading Metrics")
271
 
272
+ with gr.TabItem("堂 Sentiment Analysis"):
273
  with gr.Row():
274
  sentiment_gauge = gr.Plot(label="Sentiment Score")
275
  news_display = gr.HTML(label="Market News")
276
 
277
+ with gr.TabItem("嶋 Fundamentals"):
278
  with gr.Row():
279
  fundamentals_gauge = gr.Plot(label="Strength Index")
280
  fundamentals_table = gr.Dataframe(
 
317
  server_port=7860,
318
  share=False,
319
  show_api=True
320
+ )