Update file trainer.cli.py
Browse files- trainer.cli.py +13 -3
trainer.cli.py
CHANGED
|
@@ -17,7 +17,8 @@ parser = ArgumentParser(
|
|
| 17 |
description=''
|
| 18 |
)
|
| 19 |
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
if __name__ == '__main__':
|
|
@@ -42,8 +43,6 @@ if __name__ == '__main__':
|
|
| 42 |
batches, num_batches = dataset.batch(ids)
|
| 43 |
config.trainer.num_batches = num_batches
|
| 44 |
|
| 45 |
-
print(f"batches: {num_batches}")
|
| 46 |
-
|
| 47 |
|
| 48 |
model = Model(config.model)
|
| 49 |
wandb = Wandb(config.wandb)
|
|
@@ -52,6 +51,17 @@ if __name__ == '__main__':
|
|
| 52 |
config.trainer.wandb = wandb
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
trainer = Trainer(config.trainer)
|
| 57 |
trainer.train(batches)
|
|
|
|
|
|
|
|
|
| 17 |
description=''
|
| 18 |
)
|
| 19 |
|
| 20 |
+
import torch
|
| 21 |
+
import sys
|
| 22 |
|
| 23 |
|
| 24 |
if __name__ == '__main__':
|
|
|
|
| 43 |
batches, num_batches = dataset.batch(ids)
|
| 44 |
config.trainer.num_batches = num_batches
|
| 45 |
|
|
|
|
|
|
|
| 46 |
|
| 47 |
model = Model(config.model)
|
| 48 |
wandb = Wandb(config.wandb)
|
|
|
|
| 51 |
config.trainer.wandb = wandb
|
| 52 |
|
| 53 |
|
| 54 |
+
# Get the total amount of VRAM allocated by PyTorch
|
| 55 |
+
vram_allocated = torch.cuda.memory_allocated()
|
| 56 |
+
|
| 57 |
+
# Get the size of the model object
|
| 58 |
+
model_size = sys.getsizeof(model)
|
| 59 |
+
|
| 60 |
+
print(f"Total VRAM allocated by PyTorch: {vram_allocated} bytes")
|
| 61 |
+
print(f"Size of model object: {model_size} bytes")
|
| 62 |
+
|
| 63 |
|
| 64 |
trainer = Trainer(config.trainer)
|
| 65 |
trainer.train(batches)
|
| 66 |
+
|
| 67 |
+
|