Spaces:
Sleeping
Sleeping
Timo
commited on
Commit
·
4205025
1
Parent(s):
2866e65
Works now
Browse files- src/draft_model.py +4 -2
- src/helpers.py +3 -3
- src/streamlit_app.py +48 -38
src/draft_model.py
CHANGED
|
@@ -22,7 +22,8 @@ ENCODING_FILE = "card_encodings.pt"
|
|
| 22 |
|
| 23 |
class DraftModel:
|
| 24 |
def __init__(self):
|
| 25 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 26 |
|
| 27 |
weight_path = hf_hub_download(
|
| 28 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
|
@@ -37,9 +38,10 @@ class DraftModel:
|
|
| 37 |
# ---- load network ---------------------------------------------------
|
| 38 |
|
| 39 |
|
| 40 |
-
self.net = MLP_CrossAttention(**cfg)
|
| 41 |
self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
|
| 42 |
self.net.eval()
|
|
|
|
| 43 |
|
| 44 |
# ---- embeddings – one-time load ------------------------------------
|
| 45 |
self.embed_dict = get_embedding_dict(
|
|
|
|
| 22 |
|
| 23 |
class DraftModel:
|
| 24 |
def __init__(self):
|
| 25 |
+
#self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
self.device = torch.device("cpu") # Force CPU for compatibility
|
| 27 |
|
| 28 |
weight_path = hf_hub_download(
|
| 29 |
repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model"
|
|
|
|
| 38 |
# ---- load network ---------------------------------------------------
|
| 39 |
|
| 40 |
|
| 41 |
+
self.net = MLP_CrossAttention(**cfg)
|
| 42 |
self.net.load_state_dict(torch.load(weight_path, map_location=self.device))
|
| 43 |
self.net.eval()
|
| 44 |
+
self.net.to(self.device)
|
| 45 |
|
| 46 |
# ---- embeddings – one-time load ------------------------------------
|
| 47 |
self.embed_dict = get_embedding_dict(
|
src/helpers.py
CHANGED
|
@@ -194,13 +194,13 @@ class MLP_CrossAttention(nn.Module):
|
|
| 194 |
|
| 195 |
deck = deck.view(batch_size * deck_size, card_size)
|
| 196 |
|
| 197 |
-
deck_encoded = self.card_encoder(deck
|
| 198 |
deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
|
| 199 |
|
| 200 |
|
| 201 |
# identify padded cards
|
| 202 |
-
mask = (cards.sum(dim=-1) != 0)
|
| 203 |
-
cards_encoded = self.card_encoder(cards
|
| 204 |
|
| 205 |
if not no_attention:
|
| 206 |
# Cross-attention
|
|
|
|
| 194 |
|
| 195 |
deck = deck.view(batch_size * deck_size, card_size)
|
| 196 |
|
| 197 |
+
deck_encoded = self.card_encoder(deck)
|
| 198 |
deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
|
| 199 |
|
| 200 |
|
| 201 |
# identify padded cards
|
| 202 |
+
mask = (cards.sum(dim=-1) != 0)
|
| 203 |
+
cards_encoded = self.card_encoder(cards)
|
| 204 |
|
| 205 |
if not no_attention:
|
| 206 |
# Cross-attention
|
src/streamlit_app.py
CHANGED
|
@@ -47,7 +47,22 @@ SUPPORTED_SETS_PATH = Path("src/helper_files/supported_sets.txt")
|
|
| 47 |
def load_model():
|
| 48 |
return DraftModel()
|
| 49 |
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
@st.cache_data(show_spinner="Reading supported sets …")
|
| 53 |
def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
|
|
@@ -68,7 +83,10 @@ def p1p1_ranking(set_code: str):
|
|
| 68 |
|
| 69 |
@st.cache_data(show_spinner="Calculating card rankings …")
|
| 70 |
def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
| 72 |
pick = out["pick"]
|
| 73 |
scores = {pack[i]: score for i, score in enumerate(out["scores"])}
|
| 74 |
return pick, scores
|
|
@@ -127,19 +145,12 @@ st.session_state.setdefault("undo_stack", [])
|
|
| 127 |
# -------- Main content organised in tabs ------------------------------------
|
| 128 |
|
| 129 |
tabs = st.tabs(["Draft", "P1P1 Rankings"])
|
| 130 |
-
|
| 131 |
-
"""Make sure session_state has the pack and picks lists."""
|
| 132 |
-
if "pack" not in st.session_state:
|
| 133 |
-
st.session_state["pack"] = []
|
| 134 |
-
if "picks" not in st.session_state:
|
| 135 |
-
st.session_state["picks"] = []
|
| 136 |
def add_card(target: str, card: str):
|
| 137 |
"""target is 'pack' or 'picks'."""
|
| 138 |
-
ensure_state()
|
| 139 |
st.session_state[target].append(card)
|
| 140 |
|
| 141 |
def remove_card(target: str, key: str):
|
| 142 |
-
ensure_state()
|
| 143 |
lst = st.session_state[target]
|
| 144 |
idx = next((i for i, c in enumerate(lst) if c == key), None)
|
| 145 |
if idx is not None:
|
|
@@ -147,7 +158,6 @@ def remove_card(target: str, key: str):
|
|
| 147 |
|
| 148 |
def push_undo():
|
| 149 |
"""Save a snapshot of pack + picks so we can undo one step."""
|
| 150 |
-
ensure_state()
|
| 151 |
st.session_state["undo_stack"].append({
|
| 152 |
"pack": copy.deepcopy(st.session_state["pack"]),
|
| 153 |
"picks": copy.deepcopy(st.session_state["picks"]),
|
|
@@ -161,7 +171,6 @@ def undo_last():
|
|
| 161 |
snap = st.session_state["undo_stack"].pop()
|
| 162 |
st.session_state["pack"] = snap["pack"]
|
| 163 |
st.session_state["picks"] = snap["picks"]
|
| 164 |
-
st.rerun()
|
| 165 |
|
| 166 |
# --- callbacks ---
|
| 167 |
def _add_selected_to_deck():
|
|
@@ -170,7 +179,6 @@ def _add_selected_to_deck():
|
|
| 170 |
add_card("picks", val)
|
| 171 |
st.session_state["deck_selectbox"] = None # clear selection
|
| 172 |
st.toast(f"Added to deck: {val}")
|
| 173 |
-
st.rerun()
|
| 174 |
|
| 175 |
def _add_selected_to_pack():
|
| 176 |
val = st.session_state.get("pack_selectbox")
|
|
@@ -178,14 +186,12 @@ def _add_selected_to_pack():
|
|
| 178 |
add_card("pack", val)
|
| 179 |
st.session_state["pack_selectbox"] = None # clear selection
|
| 180 |
st.toast(f"Added to pack: {val}")
|
| 181 |
-
st.rerun()
|
| 182 |
|
| 183 |
|
| 184 |
# --- Tab 1: Draft ------------------------------------------------------------
|
| 185 |
|
| 186 |
with tabs[0]:
|
| 187 |
-
ensure_state()
|
| 188 |
-
|
| 189 |
|
| 190 |
if st.session_state["undo_stack"]:
|
| 191 |
st.button("↩️ Undo last action", on_click=undo_last)
|
|
@@ -220,11 +226,12 @@ with tabs[0]:
|
|
| 220 |
# Show current deck with remove buttons
|
| 221 |
if st.session_state["picks"]:
|
| 222 |
for i, card in enumerate(st.session_state["picks"]):
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
| 228 |
else:
|
| 229 |
st.caption("Deck is empty.")
|
| 230 |
|
|
@@ -243,26 +250,29 @@ with tabs[0]:
|
|
| 243 |
pack_list = st.session_state["pack"]
|
| 244 |
vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
|
| 245 |
df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
push_undo()
|
| 261 |
st.session_state["picks"].append(card)
|
| 262 |
-
st.session_state["pack"] = []
|
| 263 |
st.rerun()
|
| 264 |
-
else:
|
| 265 |
-
st.caption("Pack is empty.")
|
| 266 |
|
| 267 |
# --- Tab 2: Card rankings ----------------------------------------------------
|
| 268 |
|
|
|
|
| 47 |
def load_model():
|
| 48 |
return DraftModel()
|
| 49 |
|
| 50 |
+
|
| 51 |
+
if "model" not in st.session_state:
|
| 52 |
+
st.session_state.model = load_model() # your class
|
| 53 |
+
if "deck" not in st.session_state:
|
| 54 |
+
st.session_state.deck: List[str] = []
|
| 55 |
+
if "pack" not in st.session_state:
|
| 56 |
+
st.session_state.pack: List[str] = []
|
| 57 |
+
if "undo_stack" not in st.session_state:
|
| 58 |
+
st.session_state.undo_stack: List[str] = []
|
| 59 |
+
if "set_code" not in st.session_state:
|
| 60 |
+
# choose a default set code that exists in model.cards, e.g., "eoe"
|
| 61 |
+
st.session_state.set_code = "eoe"
|
| 62 |
+
|
| 63 |
+
model = st.session_state.model
|
| 64 |
+
|
| 65 |
+
|
| 66 |
|
| 67 |
@st.cache_data(show_spinner="Reading supported sets …")
|
| 68 |
def get_supported_sets(path: Path = SUPPORTED_SETS_PATH) -> List[str]:
|
|
|
|
| 83 |
|
| 84 |
@st.cache_data(show_spinner="Calculating card rankings …")
|
| 85 |
def rank_cards(deck: List[str], pack: List[str]) -> List[Dict]:
|
| 86 |
+
if not deck:
|
| 87 |
+
out = model.predict(pack, deck = None)
|
| 88 |
+
else:
|
| 89 |
+
out = model.predict(pack, deck = deck)
|
| 90 |
pick = out["pick"]
|
| 91 |
scores = {pack[i]: score for i, score in enumerate(out["scores"])}
|
| 92 |
return pick, scores
|
|
|
|
| 145 |
# -------- Main content organised in tabs ------------------------------------
|
| 146 |
|
| 147 |
tabs = st.tabs(["Draft", "P1P1 Rankings"])
|
| 148 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
def add_card(target: str, card: str):
|
| 150 |
"""target is 'pack' or 'picks'."""
|
|
|
|
| 151 |
st.session_state[target].append(card)
|
| 152 |
|
| 153 |
def remove_card(target: str, key: str):
|
|
|
|
| 154 |
lst = st.session_state[target]
|
| 155 |
idx = next((i for i, c in enumerate(lst) if c == key), None)
|
| 156 |
if idx is not None:
|
|
|
|
| 158 |
|
| 159 |
def push_undo():
|
| 160 |
"""Save a snapshot of pack + picks so we can undo one step."""
|
|
|
|
| 161 |
st.session_state["undo_stack"].append({
|
| 162 |
"pack": copy.deepcopy(st.session_state["pack"]),
|
| 163 |
"picks": copy.deepcopy(st.session_state["picks"]),
|
|
|
|
| 171 |
snap = st.session_state["undo_stack"].pop()
|
| 172 |
st.session_state["pack"] = snap["pack"]
|
| 173 |
st.session_state["picks"] = snap["picks"]
|
|
|
|
| 174 |
|
| 175 |
# --- callbacks ---
|
| 176 |
def _add_selected_to_deck():
|
|
|
|
| 179 |
add_card("picks", val)
|
| 180 |
st.session_state["deck_selectbox"] = None # clear selection
|
| 181 |
st.toast(f"Added to deck: {val}")
|
|
|
|
| 182 |
|
| 183 |
def _add_selected_to_pack():
|
| 184 |
val = st.session_state.get("pack_selectbox")
|
|
|
|
| 186 |
add_card("pack", val)
|
| 187 |
st.session_state["pack_selectbox"] = None # clear selection
|
| 188 |
st.toast(f"Added to pack: {val}")
|
| 189 |
+
#st.rerun()
|
| 190 |
|
| 191 |
|
| 192 |
# --- Tab 1: Draft ------------------------------------------------------------
|
| 193 |
|
| 194 |
with tabs[0]:
|
|
|
|
|
|
|
| 195 |
|
| 196 |
if st.session_state["undo_stack"]:
|
| 197 |
st.button("↩️ Undo last action", on_click=undo_last)
|
|
|
|
| 226 |
# Show current deck with remove buttons
|
| 227 |
if st.session_state["picks"]:
|
| 228 |
for i, card in enumerate(st.session_state["picks"]):
|
| 229 |
+
name_col, rm_col = st.columns([6, 3], gap="small")
|
| 230 |
+
name_col.write(card)
|
| 231 |
+
with rm_col:
|
| 232 |
+
if st.button("Remove", key=f"rm-deck-{i}", use_container_width=True):
|
| 233 |
+
remove_card("picks", card)
|
| 234 |
+
st.rerun()
|
| 235 |
else:
|
| 236 |
st.caption("Deck is empty.")
|
| 237 |
|
|
|
|
| 250 |
pack_list = st.session_state["pack"]
|
| 251 |
vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
|
| 252 |
df_scores = pd.DataFrame({"Card": pack_list, "Score": vals})
|
| 253 |
+
|
| 254 |
+
pack_list = st.session_state["pack"]
|
| 255 |
+
vals = [scores.get(c) if scores and c in scores else np.nan for c in pack_list]
|
| 256 |
+
|
| 257 |
+
# header row
|
| 258 |
+
h1, h2, h3 = st.columns([6, 2, 3])
|
| 259 |
+
h1.markdown("**Card**")
|
| 260 |
+
h2.markdown("**Score**")
|
| 261 |
+
h3.markdown("**Pick**")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# rows
|
| 265 |
+
for i, card in enumerate(pack_list):
|
| 266 |
+
s = vals[i]
|
| 267 |
+
c1, c2, c3 = st.columns([6, 2, 3], gap="small")
|
| 268 |
+
c1.write(card)
|
| 269 |
+
c2.markdown(f"<div class='score-cell'>{'' if np.isnan(s) else f'{s:.4f}'}</div>", unsafe_allow_html=True)
|
| 270 |
+
with c3:
|
| 271 |
+
if st.button("Pick", key=f"pick_btn_{i}", use_container_width=True, help="Add to deck & clear pack"):
|
| 272 |
push_undo()
|
| 273 |
st.session_state["picks"].append(card)
|
| 274 |
+
st.session_state["pack"] = [] # or generate_booster(set_code)
|
| 275 |
st.rerun()
|
|
|
|
|
|
|
| 276 |
|
| 277 |
# --- Tab 2: Card rankings ----------------------------------------------------
|
| 278 |
|