Spaces:
Paused
Paused
Update app_chat.py
Browse files- app_chat.py +1 -14
app_chat.py
CHANGED
|
@@ -30,20 +30,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
| 30 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
| 31 |
#tokenizer.use_default_system_prompt = False
|
| 32 |
|
| 33 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
| 34 |
-
def __init__(self, tokenizer, stops = [], encounters=1):
|
| 35 |
-
super().__init__()
|
| 36 |
-
self.stops = [stop.to("cuda") for stop in stops]
|
| 37 |
-
self.tokenizer = tokenizer
|
| 38 |
-
self.num_mamba_stop_ids = 8
|
| 39 |
-
|
| 40 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 41 |
-
last_token = input_ids[0][-self.num_mamba_stop_ids:]
|
| 42 |
-
for stop in self.stops:
|
| 43 |
-
if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
|
| 44 |
-
return True
|
| 45 |
-
return False
|
| 46 |
-
|
| 47 |
@spaces.GPU
|
| 48 |
def generate(
|
| 49 |
message: str,
|
|
@@ -66,6 +52,7 @@ def generate(
|
|
| 66 |
|
| 67 |
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
|
| 68 |
|
|
|
|
| 69 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
| 70 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
| 71 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
|
|
|
| 30 |
tokenizer.chat_template = "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}"
|
| 31 |
#tokenizer.use_default_system_prompt = False
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
@spaces.GPU
|
| 34 |
def generate(
|
| 35 |
message: str,
|
|
|
|
| 52 |
|
| 53 |
stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
|
| 54 |
|
| 55 |
+
|
| 56 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
| 57 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
| 58 |
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|