broadfield-dev commited on
Commit
c2e3eb9
·
verified ·
1 Parent(s): fc052c0

Update blueprints/summarize.py

Browse files
Files changed (1) hide show
  1. blueprints/summarize.py +3 -3
blueprints/summarize.py CHANGED
@@ -14,7 +14,7 @@ MODEL_CACHE = {
14
  "pipeline": None
15
  }
16
 
17
- def get_pipeline(model_name, task_type):
18
  """
19
  Retrieves a pipeline from cache or loads it if it's new.
20
  """
@@ -25,7 +25,6 @@ def get_pipeline(model_name, task_type):
25
  return MODEL_CACHE["pipeline"]
26
 
27
  # Authentication
28
- hf_token = os.getenv("HF_TOKEN")
29
  if hf_token:
30
  HfFolder.save_token(hf_token)
31
 
@@ -52,7 +51,7 @@ def run_inference_logic(config):
52
  model_id = config['model_name']
53
  text = config['text']
54
  task_type = config['task_type']
55
-
56
  pipe = get_pipeline(model_id, task_type)
57
 
58
  if task_type == 'TOKEN_CLS':
@@ -93,6 +92,7 @@ def api_summarize():
93
  config = {
94
  "text": data['text'],
95
  "model_name": data.get('model_name', "facebook/bart-large-cnn"),
 
96
  # We force this for the specific summarization UI,
97
  # but the backend logic supports others.
98
  "task_type": "SEQ_2_SEQ_LM"
 
14
  "pipeline": None
15
  }
16
 
17
+ def get_pipeline(model_name, task_type, hf_token):
18
  """
19
  Retrieves a pipeline from cache or loads it if it's new.
20
  """
 
25
  return MODEL_CACHE["pipeline"]
26
 
27
  # Authentication
 
28
  if hf_token:
29
  HfFolder.save_token(hf_token)
30
 
 
51
  model_id = config['model_name']
52
  text = config['text']
53
  task_type = config['task_type']
54
+ hf_token = config['hf_token']
55
  pipe = get_pipeline(model_id, task_type)
56
 
57
  if task_type == 'TOKEN_CLS':
 
92
  config = {
93
  "text": data['text'],
94
  "model_name": data.get('model_name', "facebook/bart-large-cnn"),
95
+ "hf_token": data.get('hf_token',os.getenv("HF_TOKEN"),'')
96
  # We force this for the specific summarization UI,
97
  # but the backend logic supports others.
98
  "task_type": "SEQ_2_SEQ_LM"