Keeby-smilyai committed · verified
Commit f49b7f0 · 1 Parent(s): 2eb592a

Update app.py

Files changed (1):
  1. app.py (+13 −15)
app.py CHANGED
@@ -461,8 +461,9 @@ class ModelWrapper:
 
         print(f"✅ Model loaded: {self.d_model}d × {self.n_layers}L × {self.n_heads}H")
 
-    def generate(self, prompt: str, max_new_tokens: int = 200,
-                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
+    def generate_stream(self, prompt: str, max_new_tokens: int = 200,
+                        temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
+        """Generator that yields tokens one at a time for streaming"""
         # Format prompt correctly (NO newline between User: and Sam:)
         if not prompt.startswith("User:"):
             prompt = f"User: {prompt} Sam:"
@@ -479,6 +480,7 @@ class ModelWrapper:
 
         rng = random.PRNGKey(42)
         generated_ids = input_ids
+        response_text = ""
 
         # Generate tokens
         for _ in range(max_new_tokens):
@@ -504,21 +506,17 @@ class ModelWrapper:
 
             generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)
 
+            # Decode the new token
+            token_id = int(next_token[0, 0])
+
             # Stop on EOS
-            if next_token[0, 0] == self.tokenizer.token_to_id("<|endoftext|>"):
+            if token_id == self.tokenizer.token_to_id("<|endoftext|>"):
                 break
-
-        generated_text = self.tokenizer.decode(generated_ids[0].tolist())
-
-        # Extract response after "Sam:"
-        if "Sam:" in generated_text:
-            response = generated_text.split("Sam:")[-1].strip()
-            # Clean up
-            if "<|endoftext|>" in response:
-                response = response.split("<|endoftext|>")[0].strip()
-            return response
-        else:
-            return generated_text
+
+            # Decode and yield the token
+            token_text = self.tokenizer.decode([token_id])
+            response_text += token_text
+            yield response_text
 
 # ==============================================================================
 # GRADIO INTERFACE
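
For context on the change: `generate` built the full response and returned it once, while `generate_stream` is a Python generator that yields the accumulated reply after every sampled token, so the UI can render text incrementally. Because generation starts right after the "Sam:" marker in the prompt, and the EOS token triggers a break before it is decoded, the yielded text is already clean; that is why the old split-on-"Sam:" and "<|endoftext|>" cleanup could be deleted. The sketch below shows one way such a generator plugs into a streaming Gradio chat callback. It is illustrative only, not part of the commit: the actual GRADIO INTERFACE section of app.py is outside this diff, and `model` is assumed to be an already-constructed ModelWrapper.

import gradio as gr

# Illustrative wiring only; the real GRADIO INTERFACE section of app.py is
# not shown in this diff. Assumes `model` is an already-loaded ModelWrapper.

def chat_fn(message, history):
    # generate_stream yields the cumulative response so far, which matches
    # the contract gr.ChatInterface expects from a streaming generator
    # function: each yield replaces the reply displayed so far.
    yield from model.generate_stream(message, max_new_tokens=200,
                                     temperature=0.8, top_k=50, top_p=0.9)

demo = gr.ChatInterface(chat_fn)
demo.queue().launch()

Yielding the cumulative string (rather than individual token fragments) keeps the callback stateless from Gradio's point of view: the front end simply re-renders the latest value on each yield.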