abrakjamson commited on
Commit
e3eaad2
·
1 Parent(s): 838d0f7

correction

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -29,6 +29,7 @@ global isModelDefined
29
  isModelDefined = False
30
 
31
  def defineModel():
 
32
  global model
33
  global isModelDefined
34
  global cuda
@@ -397,6 +398,7 @@ def train_model_persona(positive_text, negative_text):
397
  positive_list,
398
  negative_list,
399
  output_suffixes)
 
400
  global model
401
  model.reset()
402
  output_model = ControlVector.train(model, tokenizer, dataset)
@@ -423,6 +425,7 @@ def train_model_facts(positive_text, negative_text):
423
  negative_text,
424
  fact_suffixes
425
  )
 
426
  global model
427
  model.reset()
428
  output_model = ControlVector.train(model, tokenizer, dataset)
 
29
  isModelDefined = False
30
 
31
  def defineModel():
32
+ """ Manging the control flow of this to support ZeroGPU behavior"""
33
  global model
34
  global isModelDefined
35
  global cuda
 
398
  positive_list,
399
  negative_list,
400
  output_suffixes)
401
+ defineModel()
402
  global model
403
  model.reset()
404
  output_model = ControlVector.train(model, tokenizer, dataset)
 
425
  negative_text,
426
  fact_suffixes
427
  )
428
+ defineModel()
429
  global model
430
  model.reset()
431
  output_model = ControlVector.train(model, tokenizer, dataset)