Timo commited on
Commit
4205025
·
1 Parent(s): 2866e65
Files changed (3) hide show
  1. src/draft_model.py +4 -2
  2. src/helpers.py +3 -3
  3. src/streamlit_app.py +48 -38
src/draft_model.py CHANGED
@@ -22,7 +22,8 @@ ENCODING_FILE = "card_encodings.pt"
22
 
23
  class DraftModel:
24
  def __init__(self):
25
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
26
 
27
  weight_path = hf_hub_download(
28
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
@@ -37,9 +38,10 @@ class DraftModel:
37
  # ---- load network ---------------------------------------------------
38
 
39
 
40
- self.net = MLP_CrossAttention(**cfg).to(self.device)
41
  self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
42
  self.net.eval()
 
43
 
44
  # ---- embeddings – one-time load ------------------------------------
45
  self.embed_dict = get_embedding_dict(
 
22
 
23
  class DraftModel:
24
  def __init__(self):
25
+ #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ self.device = torch.device("cpu") # Force CPU for compatibility
27
 
28
  weight_path = hf_hub_download(
29
  repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
 
38
  # ---- load network ---------------------------------------------------
39
 
40
 
41
+ self.net = MLP_CrossAttention(**cfg)
42
  self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
43
  self.net.eval()
44
+ self.net.to(self.device)
45
 
46
  # ---- embeddings – one-time load ------------------------------------
47
  self.embed_dict = get_embedding_dict(
src/helpers.py CHANGED
@@ -194,13 +194,13 @@ class MLP_CrossAttention(nn.Module):
194
 
195
  deck = deck.view(batch_size * deck_size, card_size)
196
 
197
- deck_encoded = self.card_encoder(deck.cuda())
198
  deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
199
 
200
 
201
  # identify padded cards
202
- mask = (cards.sum(dim=-1) != 0).cuda()
203
- cards_encoded = self.card_encoder(cards.cuda())
204
 
205
  if not no_attention:
206
  # Cross-attention
 
194
 
195
  deck = deck.view(batch_size * deck_size, card_size)
196
 
197
+ deck_encoded = self.card_encoder(deck)
198
  deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
199
 
200
 
201
  # identify padded cards
202
+ mask = (cards.sum(dim=-1) != 0)
203
+ cards_encoded = self.card_encoder(cards)
204
 
205
  if not no_attention:
206
  # Cross-attention
src/streamlit_app.py CHANGED
@@ -47,7 +47,22 @@ SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
47
  def load_model():
48
  return DraftModel()
49
 
50
- model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  @st.cache_data(show_spinner="Reading supported sets …")
53
  def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
@@ -68,7 +83,10 @@ def p1p1_ranking(set_code: str):
68
 
69
  @st.cache_data(show_spinner="Calculating card rankings …")
70
  def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
71
- out = model.predict(pack, deck)
 
 
 
72
  pick = out["pick"]
73
  scores = {pack[i]: score for i, score in enumerate(out["scores"])}
74
  return pick, scores
@@ -127,19 +145,12 @@ st.session_state.setdefault("undo_stack", [])
127
  # -------- Main content organised in tabs ------------------------------------
128
 
129
  tabs = st.tabs(["Draft", "P1P1 Rankings"])
130
- def ensure_state():
131
- """Make sure session_state has the pack and picks lists."""
132
- if "pack" not in st.session_state:
133
- st.session_state["pack"] = []
134
- if "picks" not in st.session_state:
135
- st.session_state["picks"] = []
136
  def add_card(target: str, card: str):
137
  """target is 'pack' or 'picks'."""
138
- ensure_state()
139
  st.session_state[target].append(card)
140
 
141
  def remove_card(target: str, key: str):
142
- ensure_state()
143
  lst = st.session_state[target]
144
  idx = next((i for i, c in enumerate(lst) if c == key), None)
145
  if idx is not None:
@@ -147,7 +158,6 @@ def remove_card(target: str, key: str):
147
 
148
  def push_undo():
149
  """Save a snapshot of pack + picks so we can undo one step."""
150
- ensure_state()
151
  st.session_state["undo_stack"].append({
152
  "pack": copy.deepcopy(st.session_state["pack"]),
153
  "picks": copy.deepcopy(st.session_state["picks"]),
@@ -161,7 +171,6 @@ def undo_last():
161
  snap = st.session_state["undo_stack"].pop()
162
  st.session_state["pack"] = snap["pack"]
163
  st.session_state["picks"] = snap["picks"]
164
- st.rerun()
165
 
166
  # --- callbacks ---
167
  def _add_selected_to_deck():
@@ -170,7 +179,6 @@ def _add_selected_to_deck():
170
  add_card("picks", val)
171
  st.session_state["deck_selectbox"] = None # clear selection
172
  st.toast(f"Added to deck: {val}")
173
- st.rerun()
174
 
175
  def _add_selected_to_pack():
176
  val = st.session_state.get("pack_selectbox")
@@ -178,14 +186,12 @@ def _add_selected_to_pack():
178
  add_card("pack", val)
179
  st.session_state["pack_selectbox"] = None # clear selection
180
  st.toast(f"Added to pack: {val}")
181
- st.rerun()
182
 
183
 
184
  # --- Tab 1: Draft ------------------------------------------------------------
185
 
186
  with tabs[0]:
187
- ensure_state()
188
-
189
 
190
  if st.session_state["undo_stack"]:
191
  st.button("↩️ Undo last action", on_click=undo_last)
@@ -220,11 +226,12 @@ with tabs[0]:
220
  # Show current deck with remove buttons
221
  if st.session_state["picks"]:
222
  for i, card in enumerate(st.session_state["picks"]):
223
- cols = st.columns([6, 2])
224
- cols[0].write(card)
225
- if cols[1].button("Remove", key=f"rm-deck-{i}"):
226
- remove_card("picks", card)
227
- st.rerun()
 
228
  else:
229
  st.caption("Deck is empty.")
230
 
@@ -243,26 +250,29 @@ with tabs[0]:
243
  pack_list = st.session_state["pack"]
244
  vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
245
  df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
246
- tbl_col, btn_col = st.columns([4, 1], gap="small")
247
- with tbl_col:
248
- st.dataframe(
249
- df_scores,
250
- use_container_width=True,
251
- column_config={
252
- "Score": st.column_config.NumberColumn(format="%.4f")
253
- },
254
- hide_index=True,
255
- )
256
- with btn_col:
257
- for i, card in enumerate(pack_list):
258
- cols = st.columns([8, 4])
259
- if cols[1].button("Add", key=f"add_clear_{i}"):
 
 
 
 
 
260
  push_undo()
261
  st.session_state["picks"].append(card)
262
- st.session_state["pack"] = [] # or generate_booster(set_code)
263
  st.rerun()
264
- else:
265
- st.caption("Pack is empty.")
266
 
267
  # --- Tab 2: Card rankings ----------------------------------------------------
268
 
 
47
  def load_model():
48
  return DraftModel()
49
 
50
+
51
+ if "model" not in st.session_state:
52
+ st.session_state.model = load_model() # your class
53
+ if "deck" not in st.session_state:
54
+ st.session_state.deck: List[str] = []
55
+ if "pack" not in st.session_state:
56
+ st.session_state.pack: List[str] = []
57
+ if "undo_stack" not in st.session_state:
58
+ st.session_state.undo_stack: List[str] = []
59
+ if "set_code" not in st.session_state:
60
+ # choose a default set code that exists in model.cards, e.g., "eoe"
61
+ st.session_state.set_code = "eoe"
62
+
63
+ model = st.session_state.model
64
+
65
+
66
 
67
  @st.cache_data(show_spinner="Reading supported sets …")
68
  def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
 
83
 
84
  @st.cache_data(show_spinner="Calculating card rankings …")
85
  def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
86
+ if not deck:
87
+ out = model.predict(pack, deck = None)
88
+ else:
89
+ out = model.predict(pack, deck = deck)
90
  pick = out["pick"]
91
  scores = {pack[i]: score for i, score in enumerate(out["scores"])}
92
  return pick, scores
 
145
  # -------- Main content organised in tabs ------------------------------------
146
 
147
  tabs = st.tabs(["Draft", "P1P1 Rankings"])
148
+
 
 
 
 
 
149
  def add_card(target: str, card: str):
150
  """target is 'pack' or 'picks'."""
 
151
  st.session_state[target].append(card)
152
 
153
  def remove_card(target: str, key: str):
 
154
  lst = st.session_state[target]
155
  idx = next((i for i, c in enumerate(lst) if c == key), None)
156
  if idx is not None:
 
158
 
159
  def push_undo():
160
  """Save a snapshot of pack + picks so we can undo one step."""
 
161
  st.session_state["undo_stack"].append({
162
  "pack": copy.deepcopy(st.session_state["pack"]),
163
  "picks": copy.deepcopy(st.session_state["picks"]),
 
171
  snap = st.session_state["undo_stack"].pop()
172
  st.session_state["pack"] = snap["pack"]
173
  st.session_state["picks"] = snap["picks"]
 
174
 
175
  # --- callbacks ---
176
  def _add_selected_to_deck():
 
179
  add_card("picks", val)
180
  st.session_state["deck_selectbox"] = None # clear selection
181
  st.toast(f"Added to deck: {val}")
 
182
 
183
  def _add_selected_to_pack():
184
  val = st.session_state.get("pack_selectbox")
 
186
  add_card("pack", val)
187
  st.session_state["pack_selectbox"] = None # clear selection
188
  st.toast(f"Added to pack: {val}")
189
+ #st.rerun()
190
 
191
 
192
  # --- Tab 1: Draft ------------------------------------------------------------
193
 
194
  with tabs[0]:
 
 
195
 
196
  if st.session_state["undo_stack"]:
197
  st.button("↩️ Undo last action", on_click=undo_last)
 
226
  # Show current deck with remove buttons
227
  if st.session_state["picks"]:
228
  for i, card in enumerate(st.session_state["picks"]):
229
+ name_col, rm_col = st.columns([6, 3], gap="small")
230
+ name_col.write(card)
231
+ with rm_col:
232
+ if st.button("Remove", key=f"rm-deck-{i}", use_container_width=True):
233
+ remove_card("picks", card)
234
+ st.rerun()
235
  else:
236
  st.caption("Deck is empty.")
237
 
 
250
  pack_list = st.session_state["pack"]
251
  vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
252
  df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
253
+
254
+ pack_list = st.session_state["pack"]
255
+ vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
256
+
257
+ # header row
258
+ h1, h2, h3 = st.columns([6, 2, 3])
259
+ h1.markdown("**Card**")
260
+ h2.markdown("**Score**")
261
+ h3.markdown("**Pick**")
262
+
263
+
264
+ # rows
265
+ for i, card in enumerate(pack_list):
266
+ s = vals[i]
267
+ c1, c2, c3 = st.columns([6, 2, 3], gap="small")
268
+ c1.write(card)
269
+ c2.markdown(f"<div class='score-cell'>{'' if np.isnan(s) else f'{s:.4f}'}</div>", unsafe_allow_html=True)
270
+ with c3:
271
+ if st.button("Pick", key=f"pick_btn_{i}", use_container_width=True, help="Add to deck & clear pack"):
272
  push_undo()
273
  st.session_state["picks"].append(card)
274
+ st.session_state["pack"] = [] # or generate_booster(set_code)
275
  st.rerun()
 
 
276
 
277
  # --- Tab 2: Card rankings ----------------------------------------------------
278