Update 3 files
Browse files- /dataset.py
- /model.py
- /trainer.cli.py
- dataset.py +1 -1
- model.py +2 -2
- trainer.cli.py +6 -1
dataset.py
CHANGED
|
@@ -31,7 +31,7 @@ class Dataset:
|
|
| 31 |
|
| 32 |
batches = []
|
| 33 |
for batch in array_reshaped:
|
| 34 |
-
tensor_batch = torch.tensor(batch, dtype=torch.long)
|
| 35 |
batches.append(tensor_batch)
|
| 36 |
|
| 37 |
return batches, num_batches
|
|
|
|
| 31 |
|
| 32 |
batches = []
|
| 33 |
for batch in array_reshaped:
|
| 34 |
+
tensor_batch = torch.tensor(batch, dtype=torch.long).to(GetDevice())
|
| 35 |
batches.append(tensor_batch)
|
| 36 |
|
| 37 |
return batches, num_batches
|
model.py
CHANGED
|
@@ -50,12 +50,12 @@ class Model:
|
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
encoded_ids = tokenizer.encode(seed_text)
|
| 53 |
-
input_ids = torch.tensor(encoded_ids).unsqueeze(0)
|
| 54 |
output = self.model.generate(input_ids, max_length=max_len)
|
| 55 |
|
| 56 |
logits = output[0].tolist()
|
| 57 |
text = tokenizer.decode(logits)
|
| 58 |
-
|
| 59 |
return text
|
| 60 |
|
| 61 |
|
|
|
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
encoded_ids = tokenizer.encode(seed_text)
|
| 53 |
+
input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
|
| 54 |
output = self.model.generate(input_ids, max_length=max_len)
|
| 55 |
|
| 56 |
logits = output[0].tolist()
|
| 57 |
text = tokenizer.decode(logits)
|
| 58 |
+
|
| 59 |
return text
|
| 60 |
|
| 61 |
|
trainer.cli.py
CHANGED
|
@@ -52,4 +52,9 @@ if __name__ == '__main__':
|
|
| 52 |
|
| 53 |
|
| 54 |
trainer = Trainer(config.trainer)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
trainer = Trainer(config.trainer)
|
| 55 |
+
|
| 56 |
+
while True:
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
#trainer.train(batches)
|