flpelerin commited on
Commit
6a41ff7
·
1 Parent(s): 7e29145

Update 3 files

Browse files

- /dataset.py
- /model.py
- /trainer.cli.py

Files changed (3) hide show
  1. dataset.py +1 -1
  2. model.py +2 -2
  3. trainer.cli.py +6 -1
dataset.py CHANGED
@@ -31,7 +31,7 @@ class Dataset:
31
 
32
  batches = []
33
  for batch in array_reshaped:
34
- tensor_batch = torch.tensor(batch, dtype=torch.long)#.to(GetDevice())
35
  batches.append(tensor_batch)
36
 
37
  return batches, num_batches
 
31
 
32
  batches = []
33
  for batch in array_reshaped:
34
+ tensor_batch = torch.tensor(batch, dtype=torch.long).to(GetDevice())
35
  batches.append(tensor_batch)
36
 
37
  return batches, num_batches
model.py CHANGED
@@ -50,12 +50,12 @@ class Model:
50
 
51
  with torch.no_grad():
52
  encoded_ids = tokenizer.encode(seed_text)
53
- input_ids = torch.tensor(encoded_ids).unsqueeze(0)#.to(GetDevice())
54
  output = self.model.generate(input_ids, max_length=max_len)
55
 
56
  logits = output[0].tolist()
57
  text = tokenizer.decode(logits)
58
-
59
  return text
60
 
61
 
 
50
 
51
  with torch.no_grad():
52
  encoded_ids = tokenizer.encode(seed_text)
53
+ input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
54
  output = self.model.generate(input_ids, max_length=max_len)
55
 
56
  logits = output[0].tolist()
57
  text = tokenizer.decode(logits)
58
+
59
  return text
60
 
61
 
trainer.cli.py CHANGED
@@ -52,4 +52,9 @@ if __name__ == '__main__':
52
 
53
 
54
  trainer = Trainer(config.trainer)
55
- trainer.train(batches)
 
 
 
 
 
 
52
 
53
 
54
  trainer = Trainer(config.trainer)
55
+
56
+ while True:
57
+ pass
58
+
59
+
60
+ #trainer.train(batches)