Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 32 |
trust_remote_code=True,
|
| 33 |
).eval()
|
| 34 |
|
| 35 |
-
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True)
|
| 36 |
|
| 37 |
class StopOnTokens(StoppingCriteria):
|
| 38 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
@@ -56,7 +56,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
| 56 |
print(f"Conversation is -\n{conversation}")
|
| 57 |
stop = StopOnTokens()
|
| 58 |
|
| 59 |
-
input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
|
| 60 |
#input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
| 61 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
| 62 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
|
@@ -64,8 +64,8 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
| 64 |
|
| 65 |
generate_kwargs = dict(
|
| 66 |
input_ids=input_ids,
|
| 67 |
-
max_new_tokens=max_new_tokens,
|
| 68 |
streamer=streamer,
|
|
|
|
| 69 |
do_sample=True,
|
| 70 |
top_k=1,
|
| 71 |
temperature=temperature,
|
|
|
|
| 32 |
trust_remote_code=True,
|
| 33 |
).eval()
|
| 34 |
|
| 35 |
+
tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True, use_fast=False)
|
| 36 |
|
| 37 |
class StopOnTokens(StoppingCriteria):
|
| 38 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
|
| 56 |
print(f"Conversation is -\n{conversation}")
|
| 57 |
stop = StopOnTokens()
|
| 58 |
|
| 59 |
+
input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(next(model.parameters()).device)
|
| 60 |
#input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
|
| 61 |
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
| 62 |
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
|
|
|
|
| 64 |
|
| 65 |
generate_kwargs = dict(
|
| 66 |
input_ids=input_ids,
|
|
|
|
| 67 |
streamer=streamer,
|
| 68 |
+
max_new_tokens=max_new_tokens,
|
| 69 |
do_sample=True,
|
| 70 |
top_k=1,
|
| 71 |
temperature=temperature,
|