File size: 34,866 Bytes
95f8934 24dd073 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 db91874 95f8934 24dd073 95f8934 db91874 95f8934 db91874 95f8934 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 |
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "datasets>=2.14.0",
# "accelerate>=0.24.0",
# "huggingface-hub",
# "pillow>=12.0.0",
# "jiwer>=3.0.0",
# "tqdm>=4.65.0",
# "transformers @ git+https://github.com/baptiste-aubertin/transformers.git@main",
# "trackio",
# ]
# ///
"""
Fine-tune LightOnOCR on OCR datasets.
LightOnOCR is an end-to-end trainable vision-language model specifically designed for OCR tasks.
This script enables fine-tuning on custom datasets for improved performance on specific domains,
languages, or document types.
Examples:
# Basic fine-tuning on IAM handwriting dataset
uv run lightonocr-finetune.py \
--dataset-id HuggingFaceM4/FineVision \
--subset iam \
--output-dir ./lightonocr-iam \
--epochs 2
# Fine-tune with frozen language model to save memory
uv run lightonocr-finetune.py \
--dataset-id HuggingFaceM4/FineVision \
--subset olmOCR-mix-0225-documents \
--output-dir ./lightonocr-docs \
--freeze-language \
--batch-size 8
# Stream large datasets to reduce memory usage
uv run lightonocr-finetune.py \
--dataset-id HuggingFaceM4/FineVision \
--subset olmOCR-mix-0225-books \
--output-dir ./lightonocr-books \
--streaming \
--shuffle-buffer-size 10000 \
--max-train-samples 5000 # Will auto-calculate max-steps
# Push to Hub with evaluation metrics (--output-dir is always required)
uv run lightonocr-finetune.py \
--dataset-id HuggingFaceM4/FineVision \
--subset iam \
--output-dir ./lightonocr-iam \
--hub-model-id username/lightonocr-iam \
--push-to-hub \
--eval-samples 100
# Run on HF Jobs with GPU and streaming
hf jobs run --gpu l4x1 \
uv run lightonocr-finetune.py \
--dataset-id custom/large-ocr-dataset \
--output-dir ./custom-ocr \
--streaming \
--epochs 3
"""
import argparse
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
import torch
from datasets import load_dataset
from huggingface_hub import login
from jiwer import cer, wer
from tqdm import tqdm
from transformers import (
AutoProcessor,
LightOnOCRForConditionalGeneration,
Trainer,
TrainingArguments,
)
# Configure logging once at import time: INFO level, each record prefixed with
# a timestamp and its level name.
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
# NOTE(review): this variable is set after `huggingface_hub` has already been
# imported above — confirm the library reads it lazily, otherwise it has no effect.
os.environ["HF_XET_HIGH_PERFORMANCE"] = "1"
# Constants for the assistant pattern in chat template
# Token-id sequence that OCRDataCollator searches for to locate the start of the
# assistant response (described there as "<|im_end|>\n<|im_start|>assistant\n").
# NOTE(review): the ids are tokenizer-specific — verify against the processor in use.
ASSISTANT_START_PATTERN = [151645, 1699, 151644, 77091, 1699]
# Default maximum token sequence length passed to the processor (truncation limit).
DEFAULT_MAX_LENGTH = 1024
# Default value for the processor's `size={"longest_edge": ...}` image argument.
DEFAULT_LONGEST_EDGE = 700
class OCRDataCollator:
    """Collate OCR examples into model-ready training batches.

    Each example must provide ``example["images"]`` and ``example["texts"]``.
    Only examples with exactly one image and one text entry are used; others
    are skipped with a warning. The text entry is a mapping whose
    ``"assistant"`` value is the target transcription.

    For every kept example a two-turn chat (user: image, assistant: text) is
    rendered with the processor's chat template. ``labels`` are built so that
    only the assistant response contributes to the loss; every other position
    is set to ``-100``.
    """

    def __init__(
        self,
        processor,
        max_length=DEFAULT_MAX_LENGTH,
        longest_edge=DEFAULT_LONGEST_EDGE,
    ):
        """Store the processor and the text/image size limits.

        Args:
            processor: Processor exposing ``apply_chat_template``, ``__call__``
                and a ``tokenizer`` with ``pad_token_id``.
            max_length: Maximum token length (inputs are truncated to it).
            longest_edge: Value forwarded as ``size={"longest_edge": ...}``.
        """
        self.processor = processor
        self.max_length = max_length
        self.longest_edge = longest_edge

    @staticmethod
    def _find_assistant_start(token_ids):
        """Return the index just past the first assistant marker, or ``None``.

        Args:
            token_ids: One sequence of input ids as a plain Python list.
        """
        pattern_len = len(ASSISTANT_START_PATTERN)
        # "+ 1" so the last valid start position is also tested.
        for idx in range(len(token_ids) - pattern_len + 1):
            if token_ids[idx : idx + pattern_len] == ASSISTANT_START_PATTERN:
                return idx + pattern_len
        return None

    def __call__(self, examples):
        """Build a training batch from a list of dataset examples.

        Args:
            examples: Iterable of mappings with ``"images"`` and ``"texts"``.

        Returns:
            The processor output with an added ``"labels"`` tensor and
            ``"pixel_values"`` cast to bfloat16, or ``None`` when every
            example in the batch was skipped.
        """
        batch_messages = []
        batch_images = []
        for example in examples:
            example_images = example["images"]
            example_texts = example["texts"]
            # Validate single image/text per example
            if len(example_images) != 1 or len(example_texts) != 1:
                logger.warning(
                    f"Skipping example with {len(example_images)} images and {len(example_texts)} texts"
                )
                continue
            batch_images.append(example_images[0].convert("RGB"))
            # `or ""` also covers an explicit None stored under "assistant".
            assistant_text = (example_texts[0].get("assistant") or "").strip()
            batch_messages.append(
                [
                    {"role": "user", "content": [{"type": "image"}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": assistant_text}],
                    },
                ]
            )
        if not batch_images:
            return None

        # Apply chat template
        texts = [
            self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            for messages in batch_messages
        ]
        # Process inputs
        inputs = self.processor(
            text=texts,
            images=batch_images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            size={"longest_edge": self.longest_edge},
        )

        # Start fully masked, then unmask only the assistant response.
        input_ids = inputs["input_ids"]
        labels = torch.full_like(input_ids, -100)
        pad_token_id = self.processor.tokenizer.pad_token_id
        for row, row_ids in enumerate(input_ids.tolist()):
            start = self._find_assistant_start(row_ids)
            if start is None:
                # Marker not found (e.g. truncated away): the row stays fully
                # masked and contributes nothing to the loss.
                continue
            # The response ends at the first padding token after its start,
            # or at the end of the sequence when there is none.
            try:
                end = row_ids.index(pad_token_id, start)
            except ValueError:
                end = len(row_ids)
            labels[row, start:end] = input_ids[row, start:end]
        inputs["labels"] = labels
        # Convert to proper dtype
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        return inputs
def _iter_eval_batches(dataset, num_samples, batch_size, is_streaming):
    """Yield lists of up to ``batch_size`` samples, ``num_samples`` in total at most."""
    if is_streaming:
        from itertools import islice

        pending = []
        for sample in islice(dataset, num_samples):
            pending.append(sample)
            if len(pending) == batch_size:
                yield pending
                pending = []
        # Flush the trailing partial batch, including when the stream ends
        # before ``num_samples`` samples were seen.
        if pending:
            yield pending
        return
    limit = min(num_samples, len(dataset))
    for start in range(0, limit, batch_size):
        stop = min(start + batch_size, limit)
        yield [dataset[i] for i in range(start, stop)]


def _generate_batch_predictions(model, processor, prompt, batch_samples, device):
    """Run generation for one batch and return the stripped decoded texts."""
    batch_images = [[s["images"][0]] for s in batch_samples]
    inputs = processor(
        text=[prompt] * len(batch_images),
        images=batch_images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=DEFAULT_MAX_LENGTH,
        size={"longest_edge": DEFAULT_LONGEST_EDGE},
    ).to(device)
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    with torch.no_grad():
        # Greedy decoding keeps the metrics reproducible, so the base and
        # fine-tuned models are compared without sampling noise.
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Drop the prompt part; everything after it is newly generated.
    # NOTE(review): assumes the prompt fills the first `input_length` positions
    # of every row (i.e. left padding) — confirm the tokenizer's padding side.
    input_length = inputs["input_ids"].shape[1]
    decoded = processor.batch_decode(
        outputs[:, input_length:], skip_special_tokens=True
    )
    return [text.strip() for text in decoded]


def evaluate_model(
    model,
    processor,
    dataset,
    num_samples: int = 50,
    batch_size: int = 8,
    device: str = "cuda",
    description: str = "Model",
    is_streaming: bool = False,
) -> Dict[str, float]:
    """
    Evaluate model on dataset and compute OCR metrics.

    Each sample must provide ``sample["images"][0]`` (the page image) and
    ``sample["texts"][0]["assistant"]`` (the reference transcription).

    Args:
        model: The model to evaluate
        processor: The processor for the model
        dataset: Dataset to evaluate on (can be streaming or regular)
        num_samples: Maximum number of samples to evaluate
        batch_size: Batch size for evaluation (must be >= 1)
        device: Device to run evaluation on
        description: Description for logging
        is_streaming: Whether the dataset is a streaming (iterable-only) dataset

    Returns:
        Dictionary with "cer" and "wer" (percentages), "perfect_matches" and
        "total_samples". All values are zero when no sample could be evaluated.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    predictions = []
    ground_truths = []
    logger.info(f"Evaluating {description} on {num_samples} samples...")

    # The prompt is identical for every sample, so render it once.
    messages = [{"role": "user", "content": [{"type": "image"}]}]
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # A streaming dataset has no length, so the requested count is the best
    # available estimate for the progress bar.
    expected = num_samples if is_streaming else min(num_samples, len(dataset))
    expected = max(expected, 0)
    total_batches = -(-expected // batch_size)  # ceiling division

    batches = _iter_eval_batches(dataset, expected, batch_size, is_streaming)
    for batch_samples in tqdm(batches, total=total_batches, desc="Evaluating"):
        ground_truths.extend(
            s["texts"][0]["assistant"].strip() for s in batch_samples
        )
        predictions.extend(
            _generate_batch_predictions(
                model, processor, prompt, batch_samples, device
            )
        )

    actual_samples = len(predictions)
    if actual_samples == 0:
        # Nothing to score: return zeros instead of passing empty lists to jiwer.
        logger.warning(f"No samples available to evaluate {description}")
        return {"cer": 0.0, "wer": 0.0, "perfect_matches": 0, "total_samples": 0}

    # Compute metrics
    cer_score = cer(ground_truths, predictions) * 100
    wer_score = wer(ground_truths, predictions) * 100
    perfect_matches = sum(
        1 for pred, gt in zip(predictions, ground_truths) if pred == gt
    )
    logger.info(
        f"CER: {cer_score:.2f}% | WER: {wer_score:.2f}% | Perfect: {perfect_matches}/{actual_samples}"
    )
    # Show a few examples
    for i, (pred, gt) in enumerate(zip(predictions[:3], ground_truths[:3])):
        match = "✅" if pred == gt else "❌"
        logger.info(f"{match} Sample {i + 1}: '{pred[:50]}...' vs '{gt[:50]}...'")
    return {
        "cer": cer_score,
        "wer": wer_score,
        "perfect_matches": perfect_matches,
        "total_samples": actual_samples,
    }
def create_model_card_content(
    model_id: str,
    dataset_id: str,
    subset: Optional[str],
    base_metrics: Dict[str, float],
    finetuned_metrics: Dict[str, float],
    training_args: "TrainingArguments",
    freeze_config: Dict[str, bool],
    base_model_id: str = "lightonai/LightOnOCR-1B-1025",
    num_train_samples: Optional[int] = None,
) -> str:
    """Generate model card content with training details and metrics.

    Args:
        model_id: Hub id (or plain name) of the fine-tuned model; the part
            after the last "/" becomes the card title.
        dataset_id: Dataset the model was trained on.
        subset: Optional dataset subset; omitted from the card when falsy.
        base_metrics: Metrics of the base model with keys "cer", "wer",
            "perfect_matches" and "total_samples".
        finetuned_metrics: Same keys, measured after fine-tuning.
        training_args: Object exposing ``num_train_epochs``,
            ``per_device_train_batch_size``, ``gradient_accumulation_steps``
            and ``learning_rate``.
        freeze_config: Mapping of component name to whether it was frozen.
        base_model_id: Hub id of the model that was fine-tuned.
        num_train_samples: Number of training samples, when known. The card
            states "Not recorded" otherwise, because the count cannot be
            derived from the batch configuration.

    Returns:
        The complete README.md content (YAML front matter plus markdown).
    """
    # Calculate improvements (positive = better: lower error, more matches)
    cer_improvement = base_metrics["cer"] - finetuned_metrics["cer"]
    wer_improvement = base_metrics["wer"] - finetuned_metrics["wer"]
    # int() keeps the "+d" format valid even if the counts arrive as floats.
    perfect_improvement = int(
        finetuned_metrics["perfect_matches"] - base_metrics["perfect_matches"]
    )
    # Determine which components were frozen
    frozen_components = [comp for comp, is_frozen in freeze_config.items() if is_frozen]
    frozen_str = (
        ", ".join(frozen_components) if frozen_components else "None (full fine-tuning)"
    )
    dataset_str = f"{dataset_id}/{subset}" if subset else dataset_id
    base_model_name = base_model_id.split("/")[-1]
    effective_batch_size = (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
    )
    train_samples_str = (
        str(num_train_samples) if num_train_samples is not None else "Not recorded"
    )
    # Build the reproduce command line by line so that a missing subset does
    # not leave an empty continuation line behind.
    command_lines = [
        "uv run https://huggingface.co/datasets/uv-scripts/transformers-training/raw/main/lightonocr-finetune.py",
        f"--dataset-id {dataset_id}",
    ]
    if subset:
        command_lines.append(f"--subset {subset}")
    command_lines.append("--output-dir ./model")
    command_lines.append(f"--epochs {training_args.num_train_epochs}")
    reproduce_command = " \\\n".join(command_lines)

    content = f"""---
license: mit
tags:
- vision
- ocr
- document-understanding
- transformers
base_model: {base_model_id}
datasets:
- {dataset_id}
metrics:
- cer
- wer
library_name: transformers
---
# {model_id.split("/")[-1]}
This model is a fine-tuned version of [{base_model_name}](https://huggingface.co/{base_model_id}) on the {dataset_str} dataset.
## Model Description
LightOnOCR is an end-to-end trainable vision-language model specifically designed for OCR tasks. This fine-tuned version has been optimized for improved performance on the target dataset.
## Training Details
### Dataset
- **Source**: {dataset_str}
- **Training samples**: {train_samples_str}
- **Validation samples**: Used for model selection
### Training Configuration
- **Epochs**: {training_args.num_train_epochs}
- **Batch size**: {training_args.per_device_train_batch_size} (effective: {effective_batch_size})
- **Learning rate**: {training_args.learning_rate}
- **Frozen components**: {frozen_str}
- **Hardware**: GPU with mixed precision (bf16)
## Evaluation Results
### Performance Comparison
| Metric | Base Model | Fine-tuned | Improvement |
|--------|------------|------------|-------------|
| **CER (%)** | {base_metrics["cer"]:.2f} | {finetuned_metrics["cer"]:.2f} | {cer_improvement:+.2f} |
| **WER (%)** | {base_metrics["wer"]:.2f} | {finetuned_metrics["wer"]:.2f} | {wer_improvement:+.2f} |
| **Perfect Matches** | {base_metrics["perfect_matches"]}/{base_metrics["total_samples"]} | {finetuned_metrics["perfect_matches"]}/{finetuned_metrics["total_samples"]} | {perfect_improvement:+d} |
*Lower is better for CER and WER. Evaluation performed on {finetuned_metrics["total_samples"]} test samples.*
## Usage
```python
from transformers import AutoProcessor, LightOnOCRForConditionalGeneration
from PIL import Image
import torch
# Load model and processor
model = LightOnOCRForConditionalGeneration.from_pretrained(
"{model_id}",
torch_dtype=torch.bfloat16,
device_map="auto"
)
processor = AutoProcessor.from_pretrained("{model_id}")
# Prepare image
image = Image.open("your_image.jpg").convert("RGB")
# Create prompt
messages = [
{{"role": "user", "content": [{{"type": "image"}}]}}
]
# Process and generate
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[text],
images=[[image]],
return_tensors="pt",
max_length=1024
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
generated_text = processor.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(generated_text)
```
## Training Script
This model was trained using the UV Scripts training pipeline. To reproduce or further fine-tune:
```bash
{reproduce_command}
```
## Citation
If you use this model, please cite:
```bibtex
@misc{{lightonocr2024,
title={{LightOnOCR: End-to-End Trainable OCR Model}},
author={{LightOn AI}},
year={{2024}},
url={{https://huggingface.co/blog/lightonai/lightonocr}}
}}
```
## License
This model is released under the MIT license.
---
*Generated on {datetime.now().strftime("%Y-%m-%d")} using [UV Scripts](https://huggingface.co/uv-scripts)*
"""
    return content
def _build_arg_parser():
    """Build the command-line parser for the fine-tuning script."""
    parser = argparse.ArgumentParser(
        description="Fine-tune LightOnOCR on OCR datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Dataset arguments
    parser.add_argument(
        "--dataset-id",
        type=str,
        default="HuggingFaceM4/FineVision",
        help="HuggingFace dataset ID",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default="iam",
        choices=["iam", "olmOCR-mix-0225-books", "olmOCR-mix-0225-documents"],
        help="Dataset subset to use (for FineVision)",
    )
    parser.add_argument(
        "--train-split",
        type=str,
        default="train[:85%]",
        help="Training split specification",
    )
    parser.add_argument(
        "--val-split",
        type=str,
        default="train[85%:95%]",
        help="Validation split specification",
    )
    parser.add_argument(
        "--test-split", type=str, default="train[95%:]", help="Test split specification"
    )
    # Streaming arguments
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="Stream the training data (Note: ignores the split arguments and uses a fixed 85/10/5 train/val/test split)",
    )
    parser.add_argument(
        "--shuffle-buffer-size",
        type=int,
        default=10000,
        help="Buffer size for shuffling when using streaming (default: 10000)",
    )
    parser.add_argument(
        "--max-train-samples",
        type=int,
        help="Maximum number of training samples when streaming (useful for quick experiments)",
    )
    # Model arguments
    parser.add_argument(
        "--model-id",
        type=str,
        default="lightonai/LightOnOCR-1B-1025",
        help="Base model ID",
    )
    parser.add_argument(
        "--freeze-vision", action="store_true", help="Freeze vision encoder"
    )
    parser.add_argument(
        "--freeze-language", action="store_true", help="Freeze language model"
    )
    parser.add_argument(
        "--freeze-projection",
        action="store_true",
        help="Freeze vision projection layer",
    )
    # Training arguments
    parser.add_argument(
        "--output-dir", type=str, required=True, help="Directory to save the model"
    )
    parser.add_argument(
        "--epochs", type=int, default=2, help="Number of training epochs"
    )
    parser.add_argument(
        "--batch-size", type=int, default=4, help="Training batch size per device"
    )
    parser.add_argument(
        "--gradient-accumulation",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=6e-5, help="Learning rate"
    )
    parser.add_argument(
        "--warmup-steps", type=int, default=10, help="Number of warmup steps"
    )
    parser.add_argument(
        "--eval-steps", type=int, default=50, help="Evaluation interval (in steps)"
    )
    parser.add_argument(
        "--save-steps",
        type=int,
        default=500,
        help="Save checkpoint interval (in steps)",
    )
    parser.add_argument(
        "--max-length", type=int, default=1024, help="Maximum sequence length"
    )
    parser.add_argument(
        "--longest-edge", type=int, default=700, help="Longest edge for image resizing"
    )
    # Evaluation arguments
    parser.add_argument(
        "--eval-samples", type=int, default=100, help="Number of samples for evaluation"
    )
    parser.add_argument(
        "--eval-batch-size", type=int, default=8, help="Batch size for evaluation"
    )
    parser.add_argument(
        "--skip-base-eval", action="store_true", help="Skip base model evaluation"
    )
    # Hub arguments
    parser.add_argument(
        "--hub-model-id", type=str, help="HuggingFace Hub model ID for pushing"
    )
    parser.add_argument(
        "--push-to-hub", action="store_true", help="Push model to HuggingFace Hub"
    )
    parser.add_argument("--hf-token", type=str, help="HuggingFace API token")
    parser.add_argument(
        "--private", action="store_true", help="Make the model private on Hub"
    )
    # Other arguments
    parser.add_argument(
        "--max-samples", type=int, help="Limit number of training samples (for testing)"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--max-steps",
        type=int,
        help="Maximum number of training steps (auto-calculated for streaming if not specified)",
    )
    return parser


def _load_datasets(args):
    """Load the splits and return ``(train_ds, val_ds, test_ds, train_size)``.

    ``train_size`` is the number of training samples that will be used, which
    is also known in streaming mode because the split boundaries are computed
    from the dataset length.
    """
    logger.info(f"Loading dataset: {args.dataset_id}/{args.subset}")
    if not args.streaming:
        train_ds = load_dataset(args.dataset_id, args.subset, split=args.train_split)
        val_ds = load_dataset(args.dataset_id, args.subset, split=args.val_split)
        test_ds = load_dataset(args.dataset_id, args.subset, split=args.test_split)
        # Limit samples if requested (non-streaming)
        if args.max_samples:
            train_ds = train_ds.select(range(min(args.max_samples, len(train_ds))))
            logger.info(f"Limited training to {len(train_ds)} samples")
        logger.info(
            f"Dataset sizes - Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}"
        )
        return train_ds, val_ds, test_ds, len(train_ds)

    logger.info("Using streaming mode for dataset loading")
    # Percentage-based split strings are not used when streaming, so the
    # validation/test rows are taken from a regular (non-streaming) load of the
    # same split: the last 15% of the rows (10% validation, 5% test).
    full_ds = load_dataset(args.dataset_id, args.subset, split="train")
    total_size = len(full_ds)
    train_end = int(0.85 * total_size)
    val_end = int(0.95 * total_size)
    val_ds = full_ds.select(range(train_end, val_end))
    test_ds = full_ds.select(range(val_end, total_size))
    # Drop the reference to the full dataset; only the selections are needed.
    del full_ds

    train_ds = load_dataset(
        args.dataset_id, args.subset, split="train", streaming=True
    )
    # Keep only the first 85% of the stream so that validation/test rows never
    # reach the trainer (the whole split was streamed before, leaking them).
    # NOTE(review): assumes the stream yields rows in the same order as the
    # non-streaming dataset — confirm for the dataset in use.
    train_ds = train_ds.take(train_end)
    # Apply shuffling with buffer for streaming dataset
    train_ds = train_ds.shuffle(seed=args.seed, buffer_size=args.shuffle_buffer_size)

    train_size = train_end
    # Limit samples if requested (for streaming); either flag is accepted.
    sample_limit = args.max_samples or args.max_train_samples
    if sample_limit:
        train_ds = train_ds.take(sample_limit)
        train_size = min(sample_limit, train_end)
        logger.info(f"Limited training to {sample_limit} samples (streaming mode)")
    logger.info(
        f"Dataset loaded - Training: streaming ({train_size} samples), Val: {len(val_ds)}, Test: {len(test_ds)}"
    )
    logger.info(
        "Note: When streaming, the split arguments are ignored. Use --max-train-samples to limit."
    )
    return train_ds, val_ds, test_ds, train_size


def _resolve_max_steps(args, train_size):
    """Return the ``max_steps`` value for the trainer, or ``None`` to train by epochs."""
    if args.max_steps:
        logger.info(f"Using user-specified max_steps: {args.max_steps}")
        return args.max_steps
    if not args.streaming:
        return None
    # A streaming dataset has no length, so the trainer needs an explicit step
    # budget. max(1, ...) prevents a zero budget when the sample count is
    # smaller than one effective batch.
    effective_batch = args.batch_size * args.gradient_accumulation
    steps_per_epoch = max(1, train_size // effective_batch)
    max_steps = max(1, steps_per_epoch * args.epochs)
    logger.info(f"Calculated max_steps from {train_size} training samples: {max_steps}")
    return max_steps


def main():
    """Run the fine-tuning pipeline end to end.

    Steps: parse arguments, require a CUDA GPU, optionally log in to the Hub,
    load the data splits, load processor and model, freeze the requested
    components, evaluate the base model (unless skipped), train, save,
    evaluate the fine-tuned model, write a model card and optionally push to
    the Hub. Exits with status 1 when no GPU is available or when
    ``--push-to-hub`` is given without a token.
    """
    args = _build_arg_parser().parse_args()

    # Check GPU availability
    if not torch.cuda.is_available():
        logger.error("CUDA is not available. This script requires a GPU.")
        logger.info("To run on HF Jobs with GPU:")
        logger.info(
            f"hf jobs run --gpu l4x1 uv run {__file__} --dataset-id {args.dataset_id} --output-dir {args.output_dir}"
        )
        sys.exit(1)
    device = "cuda"
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
    # Set environment variables for better performance
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    torch.set_float32_matmul_precision("high")

    # Login to HuggingFace if needed
    if args.push_to_hub:
        token = args.hf_token or os.environ.get("HF_TOKEN")
        if not token:
            logger.error("HF_TOKEN required for push_to_hub")
            sys.exit(1)
        login(token=token)

    # Name used for the model card title and Hub log messages. resolve() makes
    # the fallback robust to a trailing slash or a relative --output-dir.
    hub_name = args.hub_model_id or Path(args.output_dir).resolve().name

    train_ds, val_ds, test_ds, train_size = _load_datasets(args)

    # Load processor
    logger.info(f"Loading processor from {args.model_id}")
    processor = AutoProcessor.from_pretrained(args.model_id)
    # Left padding keeps the prompt flush against the generated tokens.
    processor.tokenizer.padding_side = "left"

    # Load model
    logger.info(f"Loading model from {args.model_id}")
    model = LightOnOCRForConditionalGeneration.from_pretrained(
        args.model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        device_map="auto",
    ).to(device)

    # Freeze components as requested
    freeze_config = {
        "vision_encoder": args.freeze_vision,
        "language_model": args.freeze_language,
        "vision_projection": args.freeze_projection,
    }
    if args.freeze_vision:
        for param in model.model.vision_encoder.parameters():
            param.requires_grad = False
        logger.info("Vision encoder frozen")
    if args.freeze_language:
        for param in model.model.language_model.parameters():
            param.requires_grad = False
        logger.info("Language model frozen")
    if args.freeze_projection:
        for param in model.model.vision_projection.parameters():
            param.requires_grad = False
        logger.info("Vision projection frozen")

    # Count trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(
        f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)"
    )

    # Evaluate base model. The zeroed placeholder is used when it is skipped.
    base_metrics = {
        "cer": 0.0,
        "wer": 0.0,
        "perfect_matches": 0,
        "total_samples": args.eval_samples,
    }
    if not args.skip_base_eval:
        logger.info("\n" + "=" * 80)
        logger.info("EVALUATING BASE MODEL")
        logger.info("=" * 80)
        base_metrics = evaluate_model(
            model,
            processor,
            test_ds,
            num_samples=args.eval_samples,
            batch_size=args.eval_batch_size,
            device=device,
            description="Base model",
            is_streaming=False,  # Test dataset is never streamed
        )
        torch.cuda.empty_cache()

    # Prepare data collator
    data_collator = OCRDataCollator(
        processor, max_length=args.max_length, longest_edge=args.longest_edge
    )

    max_steps = _resolve_max_steps(args, train_size)

    # Setup training arguments
    training_args_dict = {
        "output_dir": args.output_dir,
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation,
        "learning_rate": args.learning_rate,
        "weight_decay": 0.0,
        "logging_steps": 50,
        "eval_strategy": "steps",
        "eval_steps": args.eval_steps,
        "save_strategy": "steps",
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "load_best_model_at_end": True,
        "metric_for_best_model": "eval_loss",
        "bf16": True,
        "fp16": False,
        # The collator reads the raw "images"/"texts" columns itself.
        "remove_unused_columns": False,
        "dataloader_pin_memory": False,
        "gradient_checkpointing": True,
        "optim": "adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
        "warmup_steps": args.warmup_steps,
        "lr_scheduler_type": "linear",
        "push_to_hub": args.push_to_hub,
        "hub_model_id": args.hub_model_id,
        "hub_private_repo": args.private,
        # Always set: the model card reports the epoch count.
        "num_train_epochs": args.epochs,
    }
    # max_steps is mandatory for streaming and optional otherwise (--max-steps).
    if max_steps is not None:
        training_args_dict["max_steps"] = max_steps
    training_args = TrainingArguments(**training_args_dict)

    # Use smaller validation set for faster evaluation
    val_ds_small = val_ds.select(range(min(100, len(val_ds))))

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds_small,
        data_collator=data_collator,
    )

    # Train
    logger.info("\n" + "=" * 80)
    logger.info("STARTING TRAINING")
    logger.info("=" * 80)
    if args.streaming:
        logger.info(f"Training samples: streaming mode ({train_size} samples)")
    else:
        logger.info(f"Training samples: {train_size}")
    if max_steps is not None:
        logger.info(f"Max steps: {max_steps}")
    logger.info(f"Validation samples: {len(val_ds_small)}")
    logger.info(f"Effective batch size: {args.batch_size * args.gradient_accumulation}")
    trainer.train()

    # Save model
    logger.info("Saving model and processor...")
    trainer.save_model(args.output_dir)
    processor.save_pretrained(args.output_dir)

    # Evaluate fine-tuned model
    logger.info("\n" + "=" * 80)
    logger.info("EVALUATING FINE-TUNED MODEL")
    logger.info("=" * 80)
    finetuned_metrics = evaluate_model(
        model,
        processor,
        test_ds,
        num_samples=args.eval_samples,
        batch_size=args.eval_batch_size,
        device=device,
        description="Fine-tuned model",
        is_streaming=False,  # Test dataset is never streamed
    )

    # Show comparison
    if not args.skip_base_eval:
        logger.info("\n" + "=" * 80)
        logger.info("PERFORMANCE COMPARISON")
        logger.info("=" * 80)
        logger.info(
            f"{'Metric':<20} {'Base':<12} {'Fine-tuned':<12} {'Improvement':<12}"
        )
        logger.info("-" * 56)
        logger.info(
            f"{'CER (%)':<20} {base_metrics['cer']:<12.2f} {finetuned_metrics['cer']:<12.2f} {base_metrics['cer'] - finetuned_metrics['cer']:+.2f}"
        )
        logger.info(
            f"{'WER (%)':<20} {base_metrics['wer']:<12.2f} {finetuned_metrics['wer']:<12.2f} {base_metrics['wer'] - finetuned_metrics['wer']:+.2f}"
        )
        logger.info(
            f"{'Perfect Matches':<20} {base_metrics['perfect_matches']:<12} {finetuned_metrics['perfect_matches']:<12} {finetuned_metrics['perfect_matches'] - base_metrics['perfect_matches']:+d}"
        )
        logger.info("=" * 80)

    # Create and save model card
    if args.hub_model_id or args.push_to_hub:
        logger.info("Creating model card with metrics...")
        if args.skip_base_eval:
            logger.warning(
                "Base model evaluation was skipped: base metrics in the model card are placeholders (0)."
            )
        model_card_content = create_model_card_content(
            model_id=hub_name,
            dataset_id=args.dataset_id,
            subset=args.subset,
            base_metrics=base_metrics,
            finetuned_metrics=finetuned_metrics,
            training_args=training_args,
            freeze_config=freeze_config,
        )
        # Save model card
        model_card_path = Path(args.output_dir) / "README.md"
        model_card_path.write_text(model_card_content, encoding="utf-8")
        logger.info(f"Model card saved to {model_card_path}")

    if args.push_to_hub:
        logger.info(f"Pushing model to Hub: {hub_name}")
        trainer.push_to_hub()
        if args.hub_model_id:
            logger.info(
                f"✅ Model available at: https://huggingface.co/{args.hub_model_id}"
            )
        else:
            # Without --hub-model-id the full repository id is not known here.
            logger.info(f"✅ Model pushed to the Hub as '{hub_name}'")

    logger.info("\n✅ Training complete!")
    logger.info(f"Model saved to: {args.output_dir}")
    # Print example command for inference
    logger.info("\n" + "=" * 80)
    logger.info("To use the fine-tuned model:")
    logger.info("=" * 80)
    logger.info(f"""
from transformers import AutoProcessor, LightOnOCRForConditionalGeneration
from PIL import Image
model = LightOnOCRForConditionalGeneration.from_pretrained("{args.output_dir}")
processor = AutoProcessor.from_pretrained("{args.output_dir}")
# ... rest of inference code
""")
if __name__ == "__main__":
    # With no arguments, show a short usage summary instead of argparse's
    # "missing required argument" error.
    if len(sys.argv) == 1:
        print("LightOnOCR Fine-tuning Script\n")
        print("Examples:")
        print(" # Basic fine-tuning:")
        print(
            " uv run lightonocr-finetune.py --dataset-id HuggingFaceM4/FineVision --subset iam --output-dir ./model\n"
        )
        print(" # With frozen components:")
        print(
            " uv run lightonocr-finetune.py --freeze-language --output-dir ./model\n"
        )
        print(" # Stream large datasets (memory-efficient):")
        print(
            " uv run lightonocr-finetune.py --streaming --shuffle-buffer-size 10000 --output-dir ./model\n"
        )
        print(" # Push to Hub:")
        # --output-dir is a required argument, so the example must include it.
        print(
            " uv run lightonocr-finetune.py --hub-model-id username/model --push-to-hub --output-dir ./model\n"
        )
        print(" # Run on HF Jobs:")
        print(
            " hf jobs run --gpu l4x1 uv run lightonocr-finetune.py --streaming --output-dir ./model"
        )
        sys.exit(0)
    main()
|