Update app.py
app.py CHANGED

@@ -461,8 +461,9 @@ class ModelWrapper:
 
         print(f"✅ Model loaded: {self.d_model}d × {self.n_layers}L × {self.n_heads}H")
 
-    def
-
+    def generate_stream(self, prompt: str, max_new_tokens: int = 200,
+                        temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
+        """Generator that yields tokens one at a time for streaming"""
         # Format prompt correctly (NO newline between User: and Sam:)
         if not prompt.startswith("User:"):
             prompt = f"User: {prompt} Sam:"
@@ -479,6 +480,7 @@ class ModelWrapper:
 
         rng = random.PRNGKey(42)
         generated_ids = input_ids
+        response_text = ""
 
         # Generate tokens
         for _ in range(max_new_tokens):
@@ -504,21 +506,17 @@ class ModelWrapper:
 
         generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)
 
+        # Decode the new token
+        token_id = int(next_token[0, 0])
+
         # Stop on EOS
-        if
+        if token_id == self.tokenizer.token_to_id("<|endoftext|>"):
             break
-
-
-
-
-
-        response = generated_text.split("Sam:")[-1].strip()
-        # Clean up
-        if "<|endoftext|>" in response:
-            response = response.split("<|endoftext|>")[0].strip()
-            return response
-        else:
-            return generated_text
+
+        # Decode and yield the token
+        token_text = self.tokenizer.decode([token_id])
+        response_text += token_text
+        yield response_text
 
 # ==============================================================================
 # GRADIO INTERFACE
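Because generate_stream yields the accumulated response_text after every token, it can be passed straight to Gradio as a streaming callback. Below is a minimal sketch of that wiring; the Space's actual GRADIO INTERFACE section is not part of this diff, so the component names and the ModelWrapper instance called "model" here are illustrative assumptions only.

import gradio as gr

# Hypothetical wiring: "model" is assumed to be an already-loaded ModelWrapper.
def respond(message: str):
    # Each yielded value replaces the output box's contents, so the
    # reply appears to grow token by token in the UI.
    for partial in model.generate_stream(message):
        yield partial

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Message")
    out = gr.Textbox(label="Sam")
    inp.submit(respond, inputs=inp, outputs=out)

demo.launch()

Yielding the full accumulated string, rather than per-token deltas, matches Gradio's streaming convention: each value a generator function yields replaces the output component's current contents.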