stmasson committed
Commit 188cdd5 · verified · Parent: 9c8cf56

Upload scripts/train_qwen3_sft_multitask.py with huggingface_hub
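(For context, uploads like the one in this commit message are typically done with huggingface_hub's HfApi.upload_file. A minimal sketch, assuming a hypothetical repo id; only the file path and the commit message come from this page:)

from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="scripts/train_qwen3_sft_multitask.py",  # local file to upload
    path_in_repo="scripts/train_qwen3_sft_multitask.py",     # destination path in the repo
    repo_id="stmasson/qwen3-sft-multitask",                  # hypothetical repo id
    commit_message="Upload scripts/train_qwen3_sft_multitask.py with huggingface_hub",
)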

scripts/train_qwen3_sft_multitask.py CHANGED
@@ -65,7 +65,7 @@ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "1"))
 BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
 GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
 LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "1e-5"))
-MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "8192"))
+MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
 
 # LoRA (continuing from DPO adapter)
 LORA_R = int(os.environ.get("LORA_R", "32"))
@@ -164,6 +164,17 @@ val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)
 print(f"Train: {len(train_dataset)} examples")
 print(f"Validation: {len(val_dataset)} examples")
 
+# Filter out very long examples to avoid OOM
+def filter_by_length(example):
+    """Filter examples that would be too long."""
+    total_len = sum(len(m.get('content', '')) for m in example['messages'])
+    return total_len < 30000  # ~7500 tokens max
+
+print("Filtering long examples...")
+train_dataset = train_dataset.filter(filter_by_length)
+val_dataset = val_dataset.filter(filter_by_length)
+print(f"After filtering - Train: {len(train_dataset)}, Val: {len(val_dataset)}")
+
 # Format examples
 def format_example(example):
     """Format messages to text for training."""