# app.py
import os, io, math, time, warnings
warnings.filterwarnings("ignore")

from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import requests
import yfinance as yf
import gradio as gr

# ---- runtime niceties ----
os.environ.setdefault("MPLCONFIGDIR", os.getenv("MPLCONFIGDIR", "/home/user/.config/matplotlib"))
os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
for d in [
    "/home/user/.cache",
    "/home/user/.cache/huggingface",
    "/home/user/.cache/huggingface/hub",
    "/home/user/.cache/sentencetransformers",
]:
    os.makedirs(d, exist_ok=True)

# ---------------- config ----------------
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

MAX_TICKERS = 30
DEFAULT_LOOKBACK_YEARS = 10
MARKET_TICKER = "VOO"

SYNTH_ROWS = 1000  # synthetic candidate portfolios per compute

# Globals that update with horizon changes
HORIZON_YEARS = 10
RF_CODE = "DGS10"
RF_ANN = 0.0375  # refreshed at launch

# ---------------- helpers ----------------
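# The risk-free rate is matched to the investing horizon: pick the shortest FRED
# constant-maturity Treasury series (DGS2 … DGS30) whose maturity is at least the
# requested number of years, capped at 30 years.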
def fred_series_for_horizon(years: float) -> str:
    y = max(1.0, min(100.0, float(years)))
    if y <= 2: return "DGS2"
    if y <= 3: return "DGS3"
    if y <= 5: return "DGS5"
    if y <= 7: return "DGS7"
    if y <= 10: return "DGS10"
    if y <= 20: return "DGS20"
    return "DGS30"

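# Pull the latest observation of a FRED series from the public fredgraph.csv
# endpoint; values are quoted in percent, so divide by 100. Falls back to 3%
# if the request or parsing fails.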
def fetch_fred_yield_annual(code: str) -> float:
    url = f"https://fred.stlouisfed.org/graph/fredgraph.csv?id={code}"
    try:
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        df = pd.read_csv(io.StringIO(r.text))
        s = pd.to_numeric(df.iloc[:, 1], errors="coerce").dropna()
        return float(s.iloc[-1] / 100.0) if len(s) else 0.03
    except Exception:
        return 0.03

def fetch_prices_monthly(tickers: List[str], years: int) -> pd.DataFrame:
    tickers = list(dict.fromkeys([t.upper().strip() for t in tickers]))
    start = (pd.Timestamp.today(tz="UTC") - pd.DateOffset(years=years, days=7)).date()
    end = pd.Timestamp.today(tz="UTC").date()

    df = yf.download(
        tickers,
        start=start,
        end=end,
        interval="1mo",
        auto_adjust=True,
        actions=False,
        progress=False,
        group_by="column",
        threads=False,
    )

    if isinstance(df, pd.Series):
        df = df.to_frame()

    if isinstance(df.columns, pd.MultiIndex):
        lvl0 = [str(x) for x in df.columns.get_level_values(0).unique()]
        if "Close" in lvl0:
            df = df["Close"]
        elif "Adj Close" in lvl0:
            df = df["Adj Close"]
        else:
            df = df.xs(df.columns.levels[0][-1], axis=1, level=0, drop_level=True)

    cols = [c for c in tickers if c in df.columns]
    out = df[cols].dropna(how="all").ffill()
    return out

def monthly_returns(prices: pd.DataFrame) -> pd.DataFrame:
    return prices.pct_change().dropna()

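# Symbol lookup via Yahoo Finance's public search endpoint; on any error (or no
# matches) fall back to treating the raw query as a typed symbol.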
def yahoo_search(query: str):
    if not query or not str(query).strip():
        return []
    url = "https://query1.finance.yahoo.com/v1/finance/search"
    params = {"q": query.strip(), "quotesCount": 10, "newsCount": 0}
    headers = {"User-Agent": "Mozilla/5.0"}
    try:
        r = requests.get(url, params=params, headers=headers, timeout=10)
        r.raise_for_status()
        data = r.json()
        out = []
        for q in data.get("quotes", []):
            sym = q.get("symbol")
            name = q.get("shortname") or q.get("longname") or ""
            exch = q.get("exchDisp") or ""
            if sym and sym.isascii():
                out.append(f"{sym}  |  {name}  |  {exch}")
        if not out:
            out = [f"{query.strip().upper()}  |  typed symbol  |  n/a"]
        return out[:10]
    except Exception:
        return [f"{query.strip().upper()}  |  typed symbol  |  n/a"]

def validate_tickers(symbols: List[str], years: int) -> List[str]:
    base = [s for s in dict.fromkeys([t.upper().strip() for t in symbols]) if s]
    px = fetch_prices_monthly(base + [MARKET_TICKER], years)
    ok = [s for s in base if s in px.columns]
    # require market proxy to compute CAPM
    if MARKET_TICKER not in px.columns:
        return []
    return ok

# -------------- aligned moments --------------
def get_aligned_monthly_returns(symbols: List[str], years: int) -> pd.DataFrame:
    uniq = [c for c in dict.fromkeys(symbols) if c != MARKET_TICKER]
    tickers = uniq + [MARKET_TICKER]
    px = fetch_prices_monthly(tickers, years)
    rets = monthly_returns(px)
    cols = [c for c in uniq if c in rets.columns] + ([MARKET_TICKER] if MARKET_TICKER in rets.columns else [])
    R = rets[cols].dropna(how="any")
    return R.loc[:, ~R.columns.duplicated()]

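# CAPM inputs from aligned monthly returns:
#   beta_i = Cov(r_i - rf, r_m - rf) / Var(r_m - rf)   (monthly excess returns)
#   ERP    = annualized market mean return minus the annual risk-free rate
# Means are annualized by *12, volatilities by *sqrt(12), covariances by *12;
# the covariance matrix includes the market proxy so portfolios holding VOO
# get a correct σ.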
def estimate_all_moments_aligned(symbols: List[str], years: int, rf_ann: float):
    R = get_aligned_monthly_returns(symbols, years)
    if MARKET_TICKER not in R.columns or len(R) < 3:
        raise ValueError("Not enough aligned data with market proxy.")

    m = R[MARKET_TICKER]
    if isinstance(m, pd.DataFrame):
        m = m.iloc[:, 0].squeeze()

    mu_m_ann = float(m.mean() * 12.0)
    sigma_m_ann = float(m.std(ddof=1) * math.sqrt(12.0))
    erp_ann = float(mu_m_ann - rf_ann)

    rf_m = rf_ann / 12.0
    ex_m = m - rf_m
    var_m = float(np.var(ex_m.values, ddof=1))
    var_m = max(var_m, 1e-9)

    betas: Dict[str, float] = {}
    for s in [c for c in R.columns if c != MARKET_TICKER]:
        ex_s = R[s] - rf_m
        cov_sm = float(np.cov(ex_s.values, ex_m.values, ddof=1)[0, 1])
        betas[s] = cov_sm / var_m
    betas[MARKET_TICKER] = 1.0

    # include market in covariance so σ for portfolios holding VOO is correct
    asset_cols_all = list(R.columns)  # includes market
    cov_m_all = np.cov(R[asset_cols_all].values.T, ddof=1) if asset_cols_all else np.zeros((0, 0))
    covA = pd.DataFrame(cov_m_all * 12.0, index=asset_cols_all, columns=asset_cols_all)

    return {"betas": betas, "cov_ann": covA, "erp_ann": erp_ann, "sigma_m_ann": sigma_m_ann}

def capm_er(beta: float, rf_ann: float, erp_ann: float) -> float:
    return float(rf_ann + beta * erp_ann)

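# Portfolio stats: weights are normalized by gross exposure (sum of |w|) so short
# positions still count toward exposure; beta is the exposure-weighted average of
# asset betas, E[r] comes from CAPM, and σ is the historical annualized
# volatility sqrt(wᵀ Σ w).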
def portfolio_stats(weights: Dict[str, float],
                    cov_ann: pd.DataFrame,
                    betas: Dict[str, float],
                    rf_ann: float,
                    erp_ann: float) -> Tuple[float, float, float]:
    tickers = list(weights.keys())
    w = np.array([weights[t] for t in tickers], dtype=float)
    gross = float(np.sum(np.abs(w)))
    if gross <= 1e-12:
        return 0.0, rf_ann, 0.0
    w_expo = w / gross
    beta_p = float(np.dot([betas.get(t, 0.0) for t in tickers], w_expo))
    mu_capm = capm_er(beta_p, rf_ann, erp_ann)
    cov = cov_ann.reindex(index=tickers, columns=tickers).fillna(0.0).to_numpy()
    sigma_hist = float(max(w_expo.T @ cov @ w_expo, 0.0)) ** 0.5
    return beta_p, mu_capm, sigma_hist

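# Points on the Capital Market Line are mixes of the market and bills: weight a
# in the market gives σ = a·σ_m and E[r] = rf + a·ERP, so we can solve for the a
# that matches either the portfolio's σ or its expected return.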
def efficient_same_sigma(sigma_target: float, rf_ann: float, erp_ann: float, sigma_mkt: float):
    if sigma_mkt <= 1e-12:
        return 0.0, 1.0, rf_ann
    a = sigma_target / sigma_mkt
    return a, 1.0 - a, rf_ann + a * erp_ann

def efficient_same_return(mu_target: float, rf_ann: float, erp_ann: float, sigma_mkt: float):
    if abs(erp_ann) <= 1e-12:
        return 0.0, 1.0, rf_ann
    a = (mu_target - rf_ann) / erp_ann
    return a, 1.0 - a, abs(a) * sigma_mkt

# -------------- plotting --------------
def _pct(x):
    return np.asarray(x, dtype=float) * 100.0

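# Draw the CML plus the risk-free asset, the market, the user's CAPM point
# (capped at the CML value for the same σ), the two efficient market/bills
# mixes, and optionally the selected suggestion. Axes are in percent.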
def plot_cml(rf_ann, erp_ann, sigma_mkt,
             sigma_hist_p, mu_capm_p,
             same_sigma_mu, same_mu_sigma,
             sugg_sigma_hist=None, sugg_mu_capm=None) -> Image.Image:

    fig = plt.figure(figsize=(6.5, 4.3), dpi=120)

    xmax = max(0.3, sigma_mkt * 2.4, (sigma_hist_p or 0.0) * 1.6, (sugg_sigma_hist or 0.0) * 1.6)
    xs = np.linspace(0, xmax, 200)
    slope = erp_ann / max(sigma_mkt, 1e-9)
    cml = rf_ann + slope * xs

    plt.plot(_pct(xs), _pct(cml), label="CML (Market/Bills)", linewidth=1.8)
    plt.scatter([_pct(0)], [_pct(rf_ann)], label="Risk-free")
    plt.scatter([_pct(sigma_mkt)], [_pct(rf_ann + erp_ann)], label="Market")

    y_cml_at_sigma_p = rf_ann + slope * max(0.0, float(sigma_hist_p))
    y_you = min(float(mu_capm_p), y_cml_at_sigma_p)
    plt.scatter([_pct(sigma_hist_p)], [_pct(y_you)], label="Your CAPM point")

    plt.scatter([_pct(sigma_hist_p)], [_pct(same_sigma_mu)], marker="^", label="Efficient (same σ)")
    plt.scatter([_pct(same_mu_sigma)], [_pct(mu_capm_p)], marker="^", label="Efficient (same E[r])")

    if sugg_sigma_hist is not None and sugg_mu_capm is not None:
        y_cml_at_sugg = rf_ann + slope * max(0.0, float(sugg_sigma_hist))
        y_sugg = min(float(sugg_mu_capm), y_cml_at_sugg)
        plt.scatter([_pct(sugg_sigma_hist)], [_pct(y_sugg)], label="Selected Suggestion", marker="X", s=60)

    plt.xlabel("σ (historical, annualized, %)")
    plt.ylabel("CAPM E[r] (annual, %)")
    plt.legend(loc="best", fontsize=8)
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# -------------- synthetic dataset & suggestions --------------
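# Each synthetic candidate samples up to 8 of the user's tickers, draws long-only
# weights from a Dirichlet distribution (non-negative, summing to 1), and records
# its beta, CAPM expected return, and historical σ.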
def build_synthetic_dataset(universe_user: List[str],
                            covA: pd.DataFrame,
                            betas: Dict[str, float],
                            rf_ann: float,
                            erp_ann: float,
                            sigma_mkt: float,
                            n_rows: int = SYNTH_ROWS) -> pd.DataFrame:
    rng = np.random.default_rng(12345)
    assets = list(universe_user)
    if len(assets) == 0:
        return pd.DataFrame(columns=["tickers", "weights", "beta", "mu_capm", "sigma_hist"])

    rows = []
    for _ in range(n_rows):
        k = int(rng.integers(low=1, high=min(8, len(assets)) + 1))
        picks = list(rng.choice(assets, size=k, replace=False))
        w = rng.dirichlet(np.ones(k))
        beta_p = float(np.dot([betas.get(t, 0.0) for t in picks], w))
        mu_capm = capm_er(beta_p, rf_ann, erp_ann)
        sub = covA.reindex(index=picks, columns=picks).fillna(0.0).to_numpy()
        sigma_hist = float(max(w.T @ sub @ w, 0.0)) ** 0.5

        rows.append({
            "tickers": ",".join(picks),
            "weights": ",".join(f"{x:.6f}" for x in w),
            "beta": beta_p,
            "mu_capm": mu_capm,
            "sigma_hist": sigma_hist
        })
    return pd.DataFrame(rows)

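# Risk bands are defined relative to the market's historical σ:
# Low ≤ 0.8·σ_m, Medium 0.8–1.2·σ_m, High 1.2–3.0·σ_m.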
def _band_bounds(sigma_mkt: float, band: str) -> Tuple[float, float]:
    band = (band or "Medium").strip().lower()
    if band.startswith("low"):
        return 0.0, 0.8 * sigma_mkt
    if band.startswith("high"):
        return 1.2 * sigma_mkt, 3.0 * sigma_mkt
    return 0.8 * sigma_mkt, 1.2 * sigma_mkt

def _exposure_vec(row: pd.Series, universe: List[str]) -> np.ndarray:
    vec = np.zeros(len(universe))
    idx_map = {t: i for i, t in enumerate(universe)}
    ts = [t.strip() for t in str(row["tickers"]).split(",") if t.strip()]
    ws = [float(x) for x in str(row["weights"]).split(",")]
    s = sum(ws) or 1.0
    ws = [max(0.0, w) / s for w in ws]
    for t, w in zip(ts, ws):
        if t in idx_map:
            vec[idx_map[t]] = w
    return vec

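# Rank the candidates in a band by a blend of two cosine similarities:
#   alpha   · similarity(candidate exposure vector, equal-weight target)
# + (1-α)   · similarity(text embedding of the candidate, a risk-band prompt),
# using the FinLang Investopedia sentence-transformer when it loads; otherwise
# only the exposure term is used.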
def rerank_and_pick_one(df_band: pd.DataFrame,
                        universe: List[str],
                        desired_band: str,
                        alpha: float = 0.6) -> pd.Series:
    if df_band.empty:
        return pd.Series(dtype=object)

    exp_target = np.ones(len(universe))
    exp_target = exp_target / np.sum(exp_target)

    embs_ok = True
    try:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer("FinLang/finance-embeddings-investopedia")
        prompt_map = {
            "low": "low risk conservative diversified stable portfolio",
            "medium": "balanced medium risk diversified portfolio",
            "high": "high risk growth aggressive portfolio higher expected return",
        }
        prompt = prompt_map.get(desired_band.lower(), prompt_map["medium"])
        q = model.encode([prompt])
    except Exception:
        embs_ok = False
        q = None

    def _cos(a, b):
        an = np.linalg.norm(a) + 1e-12
        bn = np.linalg.norm(b) + 1e-12
        return float(np.dot(a, b) / (an * bn))

    X_exp = np.stack([_exposure_vec(r, universe) for _, r in df_band.iterrows()], axis=0)
    exp_sims = np.array([_cos(x, exp_target) for x in X_exp])

    if embs_ok:
        cand_texts = []
        for _, r in df_band.iterrows():
            cand_texts.append(
                f"portfolio with tickers {r['tickers']} having beta {float(r['beta']):.2f}, "
                f"expected return {float(r['mu_capm']):.3f}, sigma {float(r['sigma_hist']):.3f}"
            )
        from numpy.linalg import norm
        C = model.encode(cand_texts)
        qv = q.reshape(-1)
        coss = (C @ qv) / (norm(C, axis=1) * (norm(qv) + 1e-12))
        coss = np.nan_to_num(coss, nan=0.0)
    else:
        coss = np.zeros(len(df_band))

    base = alpha * exp_sims + (1 - alpha) * coss
    order = np.argsort(-base)
    best_idx = int(order[0])
    return df_band.iloc[best_idx]

def suggest_one_per_band(synth: pd.DataFrame, sigma_mkt: float, universe_user: List[str]) -> Dict[str, pd.Series]:
    out: Dict[str, pd.Series] = {}
    for band in ["Low", "Medium", "High"]:
        lo, hi = _band_bounds(sigma_mkt, band)
        pool = synth[(synth["sigma_hist"] >= lo) & (synth["sigma_hist"] <= hi)].copy()
        if pool.empty:
            if band.lower() == "low":
                pool = synth.nsmallest(50, "sigma_hist").copy()
            elif band.lower() == "high":
                pool = synth.nlargest(50, "sigma_hist").copy()
            else:
                tmp = synth.copy()
                tmp["dist_med"] = (tmp["sigma_hist"] - sigma_mkt).abs()
                pool = tmp.nsmallest(100, "dist_med").drop(columns=["dist_med"])
        chosen = rerank_and_pick_one(pool, universe_user, band)
        out[band.lower()] = chosen
    return out

# -------------- UI helpers --------------
def empty_positions_df():
    return pd.DataFrame(columns=["ticker", "amount_usd", "weight_exposure", "beta"])

def empty_suggestion_df():
    return pd.DataFrame(columns=["ticker", "weight_%", "amount_$"])

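# Horizon changes update module-level globals (HORIZON_YEARS, RF_CODE, RF_ANN)
# that compute_stream reads on its next run; nothing is returned to the UI.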
def set_horizon(years: float):
    y = max(1.0, min(100.0, float(years)))
    code = fred_series_for_horizon(y)
    rf = fetch_fred_yield_annual(code)
    global HORIZON_YEARS, RF_CODE, RF_ANN
    HORIZON_YEARS = y
    RF_CODE = code
    RF_ANN = rf

def search_tickers_cb(q: str):
    opts = yahoo_search(q)
    if not opts:
        opts = ["No matches found"]
    # Pre-select the first result and put helper text into the box
    return gr.update(
        choices=opts,
        value=opts[0],
        info="Select a symbol and click 'Add selected to portfolio'."
    )

def add_symbol(selection: str, table: Optional[pd.DataFrame]):
    if (not selection) or ("No matches" in selection) or ("Select a symbol" in selection) or ("type above" in selection):
        return (
            table if isinstance(table, pd.DataFrame) else pd.DataFrame(columns=["ticker","amount_usd"]),
            "Pick a valid match first."
        )
    symbol = selection.split("|")[0].strip().upper()

    current = []
    if isinstance(table, pd.DataFrame) and not table.empty:
        current = [str(x).upper() for x in table["ticker"].tolist() if str(x) != "nan"]
    tickers = current if symbol in current else current + [symbol]

    val = validate_tickers(tickers, years=DEFAULT_LOOKBACK_YEARS)
    tickers = [t for t in tickers if t in val]

    amt_map = {}
    if isinstance(table, pd.DataFrame) and not table.empty:
        for _, r in table.iterrows():
            t = str(r.get("ticker", "")).upper()
            if t in tickers:
                amt = pd.to_numeric(r.get("amount_usd", 0.0), errors="coerce")
                amt_map[t] = 0.0 if pd.isna(amt) else float(amt)

    new_table = pd.DataFrame({"ticker": tickers, "amount_usd": [amt_map.get(t, 0.0) for t in tickers]})
    if len(new_table) > MAX_TICKERS:
        new_table = new_table.iloc[:MAX_TICKERS]
        return new_table, f"Reached max of {MAX_TICKERS}."
    return new_table, f"Added {symbol}."

def add_symbol_table_only(selection: str, table: Optional[pd.DataFrame]):
    new_table, _msg = add_symbol(selection, table)
    return new_table

def lock_ticker_column(tb: Optional[pd.DataFrame]):
    if not isinstance(tb, pd.DataFrame) or tb.empty:
        return pd.DataFrame(columns=["ticker", "amount_usd"])
    tickers = [str(x).upper() for x in tb["ticker"].tolist()]
    amounts = pd.to_numeric(tb["amount_usd"], errors="coerce").fillna(0.0).tolist()
    val = validate_tickers(tickers, years=DEFAULT_LOOKBACK_YEARS)
    # keep ticker/amount pairs together so dropping an invalid row can't shift amounts
    pairs = [(t, a) for t, a in zip(tickers, amounts) if t in val]
    return pd.DataFrame({"ticker": [t for t, _ in pairs], "amount_usd": [a for _, a in pairs]})

def current_ticker_choices(tb: Optional[pd.DataFrame]):
    if not isinstance(tb, pd.DataFrame) or tb.empty:
        return gr.update(choices=[], value=None)
    tickers = [str(x).upper() for x in tb["ticker"].tolist() if str(x) != "nan"]
    return gr.update(choices=tickers, value=None)

def remove_selected_ticker(symbol: Optional[str], table: Optional[pd.DataFrame]):
    if not isinstance(table, pd.DataFrame) or table.empty or not symbol:
        # nothing to do
        return table if isinstance(table, pd.DataFrame) else pd.DataFrame(columns=["ticker", "amount_usd"]), gr.update()
    out = table[table["ticker"].str.upper() != symbol.upper()].copy()
    return out, current_ticker_choices(out)

# -------------- main compute (STREAMING to show progress) --------------
UNIVERSE: List[str] = [MARKET_TICKER, "QQQ", "VTI", "SOXX", "IBIT"]

def _holdings_table_from_row(row: pd.Series, budget: float) -> pd.DataFrame:
    ts = [t.strip() for t in str(row["tickers"]).split(",") if t.strip()]
    ws = [float(x) for x in str(row["weights"]).split(",")]
    s = sum(ws) if ws else 1.0
    ws = [max(0.0, w) / s for w in ws]
    return pd.DataFrame(
        [{"ticker": t, "weight_%": round(w*100.0, 2), "amount_$": round(w*budget, 0)} for t, w in zip(ts, ws)],
        columns=["ticker", "weight_%", "amount_$"]
    )

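# compute_stream is a generator: the first yield shows the loading banner and keeps
# the results column hidden, early-exit yields report input problems, and the final
# yield delivers the plot, summary, tables, CSV path, suggestion texts and the
# visibility updates. The tuple order must match the outputs list wired to
# run_btn / btn_low / btn_med / btn_high below.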
def compute_stream(
    years_lookback: int,
    table: Optional[pd.DataFrame],
    pick_band_to_show: str,  # "Low" | "Medium" | "High"
    progress=gr.Progress(track_tqdm=True),
):
    # Yield 0: show loading banner, keep right panel hidden
    loading_banner = "**🔄 Computations running…** This can take a moment."
    yield (
        None, "", empty_positions_df(), empty_suggestion_df(), None,
        "", "", "",
        gr.update(visible=False),  # right_col
        gr.update(visible=False),  # sugg_row
        gr.update(value=loading_banner, visible=True)  # status_md
    )

    progress(0.05, desc="Validating inputs…")
    # sanitize table
    if isinstance(table, pd.DataFrame):
        df = table.copy()
    else:
        df = pd.DataFrame(columns=["ticker", "amount_usd"])
    df = df.dropna(how="all")
    if "ticker" not in df.columns: df["ticker"] = []
    if "amount_usd" not in df.columns: df["amount_usd"] = []
    df["ticker"] = df["ticker"].astype(str).str.upper().str.strip()
    df["amount_usd"] = pd.to_numeric(df["amount_usd"], errors="coerce").fillna(0.0)

    symbols = [t for t in df["ticker"].tolist() if t]
    if len(symbols) == 0:
        # final yield with message; keep right panel hidden
        yield (
            None,
            "Add at least one ticker.",
            empty_positions_df(),
            empty_suggestion_df(),
            None,
            "", "", "",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="", visible=False)
        )
        return

    symbols = validate_tickers(symbols, years_lookback)
    if len(symbols) == 0:
        yield (
            None,
            "Could not validate any tickers.",
            empty_positions_df(),
            empty_suggestion_df(),
            None,
            "", "", "",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="", visible=False)
        )
        return

    global UNIVERSE
    UNIVERSE = list(sorted(set(symbols)))[:MAX_TICKERS]

    df = df[df["ticker"].isin(symbols)].copy()
    amounts = {r["ticker"]: float(r["amount_usd"]) for _, r in df.iterrows()}
    rf_ann = RF_ANN

    progress(0.25, desc="Estimating betas & covariances…")
    moms = estimate_all_moments_aligned(symbols, years_lookback, rf_ann)
    betas, covA, erp_ann, sigma_mkt = moms["betas"], moms["cov_ann"], moms["erp_ann"], moms["sigma_m_ann"]

    gross = sum(abs(v) for v in amounts.values())
    if gross <= 1e-12:
        yield (
            None,
            "All amounts are zero.",
            empty_positions_df(),
            empty_suggestion_df(),
            None,
            "", "", "",
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="", visible=False)
        )
        return

    weights = {k: v / gross for k, v in amounts.items()}

    progress(0.45, desc="Computing portfolio statistics…")
    beta_p, mu_capm, sigma_hist = portfolio_stats(weights, covA, betas, rf_ann, erp_ann)

    a_sigma, b_sigma, mu_eff_same_sigma = efficient_same_sigma(sigma_hist, rf_ann, erp_ann, sigma_mkt)
    a_mu, b_mu, sigma_eff_same_mu = efficient_same_return(mu_capm, rf_ann, erp_ann, sigma_mkt)

    progress(0.7, desc="Generating candidate portfolios…")
    user_universe = list(symbols)
    synth = build_synthetic_dataset(user_universe, covA, betas, rf_ann, erp_ann, sigma_mkt, n_rows=SYNTH_ROWS)
    csv_path = os.path.join(DATA_DIR, f"investor_profiles_{int(time.time())}.csv")
    try:
        synth.to_csv(csv_path, index=False)
    except Exception:
        csv_path = None

    progress(0.85, desc="Selecting suggestions…")
    picks = suggest_one_per_band(synth, sigma_mkt, user_universe)

    def _fmt(row: pd.Series) -> str:
        if row is None or row.empty:
            return "No pick available."
        return f"CAPM E[r] {row['mu_capm']*100:.2f}%, σ(h) {row['sigma_hist']*100:.2f}%"

    txt_low   = _fmt(picks.get("low", pd.Series(dtype=object)))
    txt_med   = _fmt(picks.get("medium", pd.Series(dtype=object)))
    txt_high  = _fmt(picks.get("high", pd.Series(dtype=object)))

    chosen_band = (pick_band_to_show or "Medium").strip().lower()
    chosen = picks.get(chosen_band, pd.Series(dtype=object))
    if chosen is None or chosen.empty:
        chosen_sigma = None
        chosen_mu = None
        sugg_table = empty_suggestion_df()
    else:
        chosen_sigma = float(chosen["sigma_hist"])
        chosen_mu = float(chosen["mu_capm"])
        sugg_table = _holdings_table_from_row(chosen, budget=gross)

    pos_table = pd.DataFrame(
        [{
            "ticker": t,
            "amount_usd": amounts.get(t, 0.0),
            "weight_exposure": weights.get(t, 0.0),
            "beta": 1.0 if t == MARKET_TICKER else betas.get(t, np.nan)
        } for t in symbols],
        columns=["ticker", "amount_usd", "weight_exposure", "beta"]
    )

    img = plot_cml(
        rf_ann, erp_ann, sigma_mkt,
        sigma_hist, mu_capm,
        mu_eff_same_sigma, sigma_eff_same_mu,
        sugg_sigma_hist=chosen_sigma, sugg_mu_capm=chosen_mu
    )

    info = "\n".join([
        "### Inputs",
        f"- Lookback years {years_lookback}",
        f"- Horizon years {int(round(HORIZON_YEARS))}",
        f"- Risk-free {rf_ann:.2%} from {RF_CODE}",
        f"- Market ERP {erp_ann:.2%}",
        f"- Market σ (hist) {sigma_mkt:.2%}",
        "",
        "### Your portfolio",
        f"- CAPM E[r] {mu_capm:.2%}",
        f"- σ (historical) {sigma_hist:.2%}",
        "",
        "### Efficient market/bills mixes (replication weights)",
        f"- **Same σ as your portfolio** → Market weight **{a_sigma:.2f}**, Bills weight **{b_sigma:.2f}** → E[r] **{mu_eff_same_sigma:.2%}**",
        f"- **Same E[r] as your portfolio** → Market weight **{a_mu:.2f}**, Bills weight **{b_mu:.2f}** → σ **{sigma_eff_same_mu:.2%}**",
        "",
        "_How to replicate:_ use a broad market ETF (e.g., VOO) for **Market** and a T-bill/money-market fund for **Bills**. ",
        "Weights can be >1 or negative. If leverage isn’t allowed, scale both weights proportionally toward 1.0.",
    ])

    # Final yield: results + reveal right column and suggestion row; hide banner
    yield (
        img, info, pos_table, sugg_table, csv_path,
        txt_low, txt_med, txt_high,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(value="", visible=False)
    )

# -------------- UI --------------
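# Layout: the left column holds symbol search, the positions table with remove
# controls, horizon/lookback inputs and the suggestion buttons; the right column
# (plot, summary, tables, CSV download) stays hidden until the first compute.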
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
:root {
  --lensiq-accent: #8b5cf6;
  --lensiq-bg: #0b1220;
  --lensiq-card: #121a2b;
  --lensiq-text: #e5e7eb;
}
.gradio-container { font-family: Inter, ui-sans-serif, system-ui, -apple-system !important; }
.lensiq-card { background: var(--lensiq-card); border-radius: 14px; padding: 14px; }
button, .gr-button { border-radius: 10px !important; }
.lensiq-status { background: #1f2937; color: #e5e7eb; border-left: 4px solid var(--lensiq-accent); padding: 10px 12px; border-radius: 8px; }
"""

with gr.Blocks(title="Efficient Portfolio Advisor", css=custom_css) as demo:
    gr.Markdown("## Efficient Portfolio Advisor")

    with gr.Row():
        # LEFT COLUMN (full width pre-compute)
        with gr.Column(scale=1) as left_col:
            with gr.Group(elem_classes="lensiq-card"):
                q = gr.Textbox(label="Search symbol")
                search_btn = gr.Button("Search")
                matches = gr.Dropdown(choices=[], label="Matches", info="Type a query and hit Search")
                add_btn = gr.Button("Add selected to portfolio")

            with gr.Group(elem_classes="lensiq-card"):
                gr.Markdown("### Portfolio positions")
                table = gr.Dataframe(
                    headers=["ticker", "amount_usd"],
                    datatype=["str", "number"],
                    row_count=0,
                    col_count=(2, "fixed")
                )

                # remove controls
                with gr.Row():
                    rm_dropdown = gr.Dropdown(choices=[], label="Remove ticker", value=None)
                    rm_btn = gr.Button("Remove selected")

            with gr.Group(elem_classes="lensiq-card"):
                horizon = gr.Number(label="Horizon in years (1–100)", value=HORIZON_YEARS, precision=0)
                lookback = gr.Slider(1, 15, value=DEFAULT_LOOKBACK_YEARS, step=1, label="Lookback years for betas & covariances")
                run_btn = gr.Button("Compute (build dataset & suggest)")

            # visible loading/status banner
            status_md = gr.Markdown("", visible=False, elem_classes="lensiq-status")

            sugg_hdr = gr.Markdown("### Suggestions", visible=False)
            with gr.Row(visible=False) as sugg_row:
                btn_low = gr.Button("Show Low")
                btn_med = gr.Button("Show Medium")
                btn_high = gr.Button("Show High")
            low_txt = gr.Markdown()
            med_txt = gr.Markdown()
            high_txt = gr.Markdown()

        # RIGHT COLUMN (hidden pre-compute)
        with gr.Column(scale=1, visible=False) as right_col:
            plot = gr.Image(label="Capital Market Line (CAPM)", type="pil")
            summary = gr.Markdown(label="Inputs & Results")
            positions = gr.Dataframe(
                label="Computed positions",
                headers=["ticker", "amount_usd", "weight_exposure", "beta"],
                datatype=["str", "number", "number", "number"],
                col_count=(4, "fixed"),
                value=empty_positions_df(),
                interactive=False
            )
            sugg_table = gr.Dataframe(
                label="Selected suggestion holdings (% / $)",
                headers=["ticker", "weight_%", "amount_$"],
                datatype=["str", "number", "number"],
                col_count=(3, "fixed"),
                value=empty_suggestion_df(),
                interactive=False
            )
            dl = gr.File(label="Generated dataset CSV", value=None, visible=True)

    # ---------- wiring ----------
    # search / add
    search_btn.click(fn=search_tickers_cb, inputs=q, outputs=matches)
    add_btn.click(fn=add_symbol_table_only, inputs=[matches, table], outputs=table)

    # keep tickers valid & refresh remove dropdown when table changes
    table.change(fn=lock_ticker_column, inputs=table, outputs=table)
    table.change(fn=current_ticker_choices, inputs=table, outputs=rm_dropdown)

    # remove a ticker
    rm_btn.click(fn=remove_selected_ticker, inputs=[rm_dropdown, table], outputs=[table, rm_dropdown])

    # horizon updates globals silently
    horizon.change(fn=set_horizon, inputs=horizon, outputs=[])

    # compute + reveal results (default Medium band); STREAMING for visible progress
    run_btn.click(
        fn=compute_stream,
        inputs=[lookback, table, gr.State("Medium")],
        outputs=[plot, summary, positions, sugg_table, dl, low_txt, med_txt, high_txt, right_col, sugg_row, status_md]
    ).then(  # after results are visible, show Suggestions header too
        lambda: gr.update(visible=True),
        None,
        [sugg_hdr]
    )

    # band buttons recompute picks quickly (also stream with banner)
    btn_low.click(
        fn=compute_stream,
        inputs=[lookback, table, gr.State("Low")],
        outputs=[plot, summary, positions, sugg_table, dl, low_txt, med_txt, high_txt, right_col, sugg_row, status_md]
    )
    btn_med.click(
        fn=compute_stream,
        inputs=[lookback, table, gr.State("Medium")],
        outputs=[plot, summary, positions, sugg_table, dl, low_txt, med_txt, high_txt, right_col, sugg_row, status_md]
    )
    btn_high.click(
        fn=compute_stream,
        inputs=[lookback, table, gr.State("High")],
        outputs=[plot, summary, positions, sugg_table, dl, low_txt, med_txt, high_txt, right_col, sugg_row, status_md]
    )

# initialize risk-free at launch
RF_CODE = fred_series_for_horizon(HORIZON_YEARS)
RF_ANN = fetch_fred_yield_annual(RF_CODE)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)