Spaces:
Running
on
Zero
Running
on
Zero
da03
commited on
Commit
·
3098025
1
Parent(s):
ae9daf1
app.py
CHANGED
|
@@ -40,15 +40,15 @@ def predict_product(num1, num2):
|
|
| 40 |
generated_ids = inputs['input_ids']
|
| 41 |
past_key_values = None
|
| 42 |
for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
|
| 43 |
-
outputs = model(
|
| 44 |
input_ids=generated_ids,
|
|
|
|
| 45 |
past_key_values=past_key_values,
|
|
|
|
| 46 |
use_cache=True
|
| 47 |
)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
|
| 51 |
-
generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
|
| 52 |
print (next_token_id)
|
| 53 |
|
| 54 |
if next_token_id.item() == eos_token_id:
|
|
|
|
| 40 |
generated_ids = inputs['input_ids']
|
| 41 |
past_key_values = None
|
| 42 |
for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
|
| 43 |
+
outputs = model.generate(
|
| 44 |
input_ids=generated_ids,
|
| 45 |
+
max_new_tokens=1,
|
| 46 |
past_key_values=past_key_values,
|
| 47 |
+
return_dict_in_generate=True,
|
| 48 |
use_cache=True
|
| 49 |
)
|
| 50 |
+
generated_ids = outputs.sequences
|
| 51 |
+
next_token_id = generated_ids[0, -1]
|
|
|
|
|
|
|
| 52 |
print (next_token_id)
|
| 53 |
|
| 54 |
if next_token_id.item() == eos_token_id:
|