abrakjamson commited on
Commit
838d0f7
·
1 Parent(s): 6b01bd2

correction

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -397,6 +397,7 @@ def train_model_persona(positive_text, negative_text):
397
  positive_list,
398
  negative_list,
399
  output_suffixes)
 
400
  model.reset()
401
  output_model = ControlVector.train(model, tokenizer, dataset)
402
  # Write file to temporary directory returning the path to Gradio for download
@@ -422,7 +423,8 @@ def train_model_facts(positive_text, negative_text):
422
  negative_text,
423
  fact_suffixes
424
  )
425
-
 
426
  output_model = ControlVector.train(model, tokenizer, dataset)
427
  filename = re.sub(r'[ <>:"/\\|?*]', '', positive_text) + '_'
428
  temp_file = tempfile.NamedTemporaryFile(
 
397
  positive_list,
398
  negative_list,
399
  output_suffixes)
400
+ global model
401
  model.reset()
402
  output_model = ControlVector.train(model, tokenizer, dataset)
403
  # Write file to temporary directory returning the path to Gradio for download
 
423
  negative_text,
424
  fact_suffixes
425
  )
426
+ global model
427
+ model.reset()
428
  output_model = ControlVector.train(model, tokenizer, dataset)
429
  filename = re.sub(r'[ <>:"/\\|?*]', '', positive_text) + '_'
430
  temp_file = tempfile.NamedTemporaryFile(