Upload train_ministral_n8n.py with huggingface_hub
Browse files- train_ministral_n8n.py +24 -2
train_ministral_n8n.py
CHANGED
|
@@ -157,9 +157,31 @@ trainer = SFTTrainer(
|
|
| 157 |
processing_class=tokenizer,
|
| 158 |
)
|
| 159 |
|
| 160 |
-
# Train
|
| 161 |
print("Starting training...")
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
# Save final model
|
| 165 |
print("Saving final model...")
|
|
|
|
| 157 |
processing_class=tokenizer,
|
| 158 |
)
|
| 159 |
|
| 160 |
+
# Train - resume from checkpoint if available on Hub
print("Starting training...")

# Resolve the checkpoint BEFORE training and keep that lookup in its own
# try block. Only Hub/listing/download problems should trigger the
# "starting fresh" fallback; an error raised during training itself must
# propagate instead of being mislabelled as a resume failure and silently
# kicking off a second, from-scratch training run.
resume_dir = None
try:
    import os

    from huggingface_hub import hf_hub_download, list_repo_files

    checkpoint_dir = "./resume-checkpoint"
    checkpoint_files = [
        "adapter_model.safetensors",
        "adapter_config.json",
        "trainer_state.json",
        "training_args.bin",
    ]

    files = list_repo_files(OUTPUT_MODEL)
    has_checkpoint = (
        any("last-checkpoint" in name for name in files)
        or "adapter_model.safetensors" in files
    )
    if has_checkpoint:
        print("Found existing checkpoint on Hub, downloading to resume...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        for filename in checkpoint_files:
            # Skip files the repo listing says are absent instead of
            # relying on a swallowed download error.
            if filename not in files:
                continue
            try:
                hf_hub_download(OUTPUT_MODEL, filename, local_dir=checkpoint_dir)
            except Exception as download_err:
                # Best effort: a missing optional file must not abort the
                # run, but report it (and never swallow KeyboardInterrupt).
                print(f"Could not download {filename}: {download_err}")
        # Only resume when the trainer state is actually present locally.
        if os.path.exists(os.path.join(checkpoint_dir, "trainer_state.json")):
            resume_dir = checkpoint_dir
except Exception as e:
    print(f"Could not resume from checkpoint: {e}, starting fresh...")
    resume_dir = None

# resume_from_checkpoint=None is the default, i.e. a fresh training run.
trainer.train(resume_from_checkpoint=resume_dir)
|
| 185 |
|
| 186 |
# Save final model
|
| 187 |
print("Saving final model...")
|