Paulhayes commited on
Commit
b8fe16c
·
verified ·
1 Parent(s): 620f5d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -49,6 +49,16 @@ except Exception:
49
 
50
  from transformers import StoppingCriteria, StoppingCriteriaList
51
 
 
 
 
 
 
 
 
 
 
 
52
  try:
53
  import faiss
54
  except Exception:
 
49
 
50
  from transformers import StoppingCriteria, StoppingCriteriaList
51
 
52
+ class StopOnTokens(StoppingCriteria):
53
+ def __init__(self, stop_token_ids: List[int]):
54
+ self.stop_token_ids = stop_token_ids
55
+
56
+ def __call__(self, input_ids: "torch.LongTensor", scores: "torch.FloatTensor", **kwargs) -> bool:
57
+ for stop_id in self.stop_token_ids:
58
+ if input_ids[0][-1] == stop_id:
59
+ return True
60
+ return False
61
+
62
  try:
63
  import faiss
64
  except Exception: