Only run predictions
Browse files
data-processing/granite-ttm.py
CHANGED
|
@@ -94,20 +94,14 @@ def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=
|
|
| 94 |
|
| 95 |
# train predictions
|
| 96 |
|
| 97 |
-
print("+" * 20, "Train
|
| 98 |
-
zeroshot_output = zeroshot_trainer.evaluate(dset_train)
|
| 99 |
-
print(zeroshot_output)
|
| 100 |
-
|
| 101 |
predictions_dict = zeroshot_trainer.predict(dset_train)
|
| 102 |
|
| 103 |
predictions_np_train = predictions_dict.predictions[0]
|
| 104 |
|
| 105 |
# test predictions
|
| 106 |
|
| 107 |
-
print("+" * 20, "Test
|
| 108 |
-
zeroshot_output = zeroshot_trainer.evaluate(dset_test)
|
| 109 |
-
print(zeroshot_output)
|
| 110 |
-
|
| 111 |
predictions_dict = zeroshot_trainer.predict(dset_test)
|
| 112 |
|
| 113 |
predictions_np_test = predictions_dict.predictions[0]
|
|
|
|
| 94 |
|
| 95 |
# train predictions
|
| 96 |
|
| 97 |
+
print("+" * 20, "Train predict zero-shot", "+" * 20)
|
|
|
|
|
|
|
|
|
|
| 98 |
predictions_dict = zeroshot_trainer.predict(dset_train)
|
| 99 |
|
| 100 |
predictions_np_train = predictions_dict.predictions[0]
|
| 101 |
|
| 102 |
# test predictions
|
| 103 |
|
| 104 |
+
print("+" * 20, "Test predict zero-shot", "+" * 20)
|
|
|
|
|
|
|
|
|
|
| 105 |
predictions_dict = zeroshot_trainer.predict(dset_test)
|
| 106 |
|
| 107 |
predictions_np_test = predictions_dict.predictions[0]
|