space update
app.py CHANGED
@@ -71,7 +71,23 @@ def load_model(custom_model_path=None):
 
     if os.path.exists(model_path):
         try:
-
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
+
+            # Handle potential key mapping issues
+            if "embedding.weight" in state_dict and "embedding.embedding.weight" not in state_dict:
+                # Map old key names to new key names
+                new_state_dict = {}
+                for key, value in state_dict.items():
+                    if key == "embedding.weight":
+                        new_state_dict["embedding.embedding.weight"] = value
+                    elif key == "pos_embedding.weight":
+                        # Skip positional embedding if not expected
+                        continue
+                    else:
+                        new_state_dict[key] = value
+                state_dict = new_state_dict
+
+            u_model.load_state_dict(state_dict)
             u_model.eval()
             print("✅ Model weights loaded successfully!")
             return u_model, u_tokenizer, f"✅ Model loaded from: {model_path}"
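The new loading path normalizes older checkpoints whose embedding weights were saved under a flat key before calling load_state_dict. Below is a minimal, self-contained sketch of the same remapping idea; remap_legacy_keys, EmbeddingWrapper, and TinyModel are illustrative stand-ins (the real app operates on u_model / UstaModel), but the rename of "embedding.weight" to "embedding.embedding.weight" and the dropping of "pos_embedding.weight" mirror the hunk above.

import torch
import torch.nn as nn


def remap_legacy_keys(state_dict):
    """Translate old checkpoint keys to the current module layout."""
    if "embedding.weight" not in state_dict or "embedding.embedding.weight" in state_dict:
        return state_dict  # already in the expected format
    remapped = {}
    for key, value in state_dict.items():
        if key == "embedding.weight":
            remapped["embedding.embedding.weight"] = value  # renamed key
        elif key == "pos_embedding.weight":
            continue  # obsolete key, dropped
        else:
            remapped[key] = value
    return remapped


class EmbeddingWrapper(nn.Module):
    """Toy submodule: the embedding now lives one level deeper."""
    def __init__(self, vocab, dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)


class TinyModel(nn.Module):
    """Toy stand-in for UstaModel, only here to exercise the remapping."""
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embedding = EmbeddingWrapper(vocab, dim)


# An "old" checkpoint: flat embedding key plus an obsolete positional embedding.
old_ckpt = {
    "embedding.weight": torch.zeros(10, 4),
    "pos_embedding.weight": torch.zeros(8, 4),
}

model = TinyModel()
model.load_state_dict(remap_legacy_keys(old_ckpt))  # loads cleanly after remapping
model.eval()

Loading with the default strict=True succeeds here because the remapped keys match the module's own state_dict exactly; without the remapping, the stale keys would raise a "missing/unexpected keys" error.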
@@ -188,15 +204,13 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
             gr.Markdown("### 📁 Model Upload (Optional)")
             model_file = gr.File(
                 label="Upload your own model.pth file",
-                file_types=[".pth", ".pt"]
-                info="Upload a custom UstaModel checkpoint to use instead of the default model"
+                file_types=[".pth", ".pt"]
             )
             upload_btn = gr.Button("Load Model", variant="primary")
             model_status_display = gr.Textbox(
                 label="Model Status",
                 value=model_status,
-                interactive=False
-                info="Shows the current model loading status"
+                interactive=False
             )
 
         with gr.Column(scale=1):
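For reference, a sketch of how an upload control like this is typically wired to a load callback. handle_upload is hypothetical (the actual app presumably routes the uploaded path to load_model via upload_btn.click); the snippet only demonstrates the File → Button → Textbox plumbing.

import os
import gradio as gr


def handle_upload(uploaded):
    # gr.File yields a file path string (or an object with .name on older Gradio versions).
    if uploaded is None:
        return "No file uploaded; using the default model."
    path = uploaded if isinstance(uploaded, str) else uploaded.name
    # The real app would call load_model(custom_model_path=path) here and return its status.
    return f"Received checkpoint: {os.path.basename(path)}"


with gr.Blocks() as demo:
    model_file = gr.File(label="Upload your own model.pth file",
                         file_types=[".pth", ".pt"])
    upload_btn = gr.Button("Load Model", variant="primary")
    model_status_display = gr.Textbox(label="Model Status", interactive=False)
    upload_btn.click(handle_upload, inputs=model_file, outputs=model_status_display)

demo.launch()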
@@ -205,8 +219,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
             gr.Markdown("### ⚙️ Generation Settings")
             system_msg = gr.Textbox(
                 value="You are Usta, a geographical knowledge assistant trained from scratch.",
-                label="System message"
-                info="Note: This model focuses on geographical knowledge"
+                label="System message"
             )
             max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens")
             temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
@@ -215,8 +228,7 @@ with gr.Blocks(title="🤖 Usta Model Chat", theme=gr.themes.Soft()) as demo:
                 maximum=1.0,
                 value=0.95,
                 step=0.05,
-                label="Top-p (nucleus sampling)"
-                info="Note: This parameter is not used by UstaModel"
+                label="Top-p (nucleus sampling)"
             )
 
             # Chat interface