Update app.py
Browse files
app.py
CHANGED
|
@@ -49,6 +49,16 @@ except Exception:
|
|
| 49 |
|
| 50 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
try:
|
| 53 |
import faiss
|
| 54 |
except Exception:
|
|
|
|
| 49 |
|
| 50 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 51 |
|
| 52 |
+
class StopOnTokens(StoppingCriteria):
|
| 53 |
+
def __init__(self, stop_token_ids: List[int]):
|
| 54 |
+
self.stop_token_ids = stop_token_ids
|
| 55 |
+
|
| 56 |
+
def __call__(self, input_ids: "torch.LongTensor", scores: "torch.FloatTensor", **kwargs) -> bool:
|
| 57 |
+
for stop_id in self.stop_token_ids:
|
| 58 |
+
if input_ids[0][-1] == stop_id:
|
| 59 |
+
return True
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
try:
|
| 63 |
import faiss
|
| 64 |
except Exception:
|