ambivalent02 commited on
Commit
e5519c9
·
verified ·
1 Parent(s): 620c4c3

Upload loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. loader.py +165 -0
loader.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from datasets import load_dataset
2
+
3
+ # raw_ds = load_dataset("simwit/omni-med-vqa-mini")
4
+ # full_dataset = raw_ds["test"]
5
+ # split = full_dataset.train_test_split(test_size=0.2, seed=42)
6
+ # train_dataset = split["train"]
7
+ # eval_dataset = split["test"]
8
+
9
+ # print("✅ SFT Dataset loaded:")
10
+ # print(f" 📚 Train samples: {len(train_dataset)}")
11
+ # print(f" 🧪 Eval samples: {len(eval_dataset)}")
12
+ # print(f"\n📝 Single Sample: [IMAGE] {train_dataset[0]['question']} {train_dataset[0]['gt_answer']} {train_dataset[0]['image_path']} {list(train_dataset[0].keys())}")
13
+ """
14
+ Convert jsonl with `image` and `conversations` into
15
+ a HuggingFace Dataset that LFM2-VL expects.
16
+ Each sample must contain:
17
+ - image : str (absolute path or relative to repo root)
18
+ - messages: List[Dict] # openai-style
19
+ """
20
+ import json, datasets
21
+ from pathlib import Path
22
+ from typing import List, Dict
23
+ import multiprocessing as mp
24
+ from PIL import Image
25
+ SYSTEM_MSG = "You are a helpful vision-language assistant."
26
+
27
+
28
+ """
29
+ Convert jsonl with `image` and `conversations` into
30
+ a HuggingFace Dataset that works with the medical sample format.
31
+ """
32
+ import json, datasets
33
+ from pathlib import Path
34
+ from typing import List, Dict
35
+ import multiprocessing as mp
36
+ from PIL import Image
37
def format_vlm_sample(sample):
    """Build an OpenAI-style two-turn conversation from a VQA sample.

    The user turn carries the image plus the question text; the assistant
    turn carries the ground-truth answer.

    Args:
        sample: mapping with keys ``image``, ``question`` and ``gt_answer``.

    Returns:
        A two-element list of message dicts (user turn, assistant turn).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": sample["image"]},
            {"type": "text", "text": sample["question"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["gt_answer"]}],
    }
    return [user_turn, assistant_turn]
49
def jsonl_to_dataset_hf_parallel(jsonl_file: str, image_root: str = "", num_workers: int = None):
    """Convert a jsonl file of ``{image, conversations}`` records into a HF Dataset.

    Each output row has:
        image     : str -- absolute path to the image file
        question  : str -- first human turn, with the ``<image>`` tag stripped
        gt_answer : str -- first gpt/assistant turn

    Rows whose image file is missing or whose question/answer is empty are
    dropped; lines that are not valid JSON are skipped with a warning.

    Args:
        jsonl_file: path to the input ``.jsonl`` file.
        image_root: directory prepended to each record's relative image path.
        num_workers: processes for ``datasets.Dataset.map``; defaults to 8.

    Returns:
        A ``datasets.Dataset`` with columns ``image``, ``question``, ``gt_answer``.
    """
    if num_workers is None:
        num_workers = 8

    # Pre-validate every line so the parallel map only sees parseable records.
    valid_lines = []
    with open(jsonl_file, encoding="utf-8") as f:
        for line_num, line in enumerate(f):
            line = line.strip()
            if not line:  # skip blank lines
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:  # narrow except: only malformed JSON
                print(f"Warning: Line {line_num}: Invalid JSON")
                continue
            if "image" in rec and "conversations" in rec:
                valid_lines.append({"line": line, "image_root": image_root, "line_num": line_num})

    print(f"Found {len(valid_lines)} valid lines to process")

    raw_dataset = datasets.Dataset.from_list(valid_lines)

    def process_example_safe(example):
        """Map function: always returns a row; 'valid' marks usable samples."""
        rec = json.loads(example["line"])
        image_path = Path(example["image_root"]) / rec["image"]

        # datasets.map cannot drop rows, so mark bad ones and filter later.
        invalid_row = {
            "image": str(image_path.absolute()),
            "question": "dummy",
            "gt_answer": "dummy",
            "valid": False,
        }

        if not image_path.exists():
            return invalid_row

        # Extract the question (human turn) and the first answer turn.
        question = ""
        gt_answer = ""
        for turn in rec["conversations"]:
            if turn["from"] == "human":
                question = turn["value"].replace("<image>", "").strip()
            elif turn["from"] in ("gpt", "assistant"):
                gt_answer = turn["value"].strip()
                break  # keep only the first answer

        if not question or not gt_answer:
            return invalid_row

        return {
            "image": str(image_path.absolute()),
            "question": question,
            "gt_answer": gt_answer,
            "valid": True,
        }

    # Process in parallel
    processed_dataset = raw_dataset.map(
        process_example_safe,
        num_proc=num_workers,
        remove_columns=["line", "image_root", "line_num"],
        desc="Processing medical QA records",
    )

    # Filter out invalid entries, then drop the helper column.
    valid_dataset = processed_dataset.filter(lambda x: x["valid"])
    valid_dataset = valid_dataset.remove_columns(["valid"])

    print(f"Valid samples after processing: {len(valid_dataset)}")

    # NOTE(review): eager PIL image loading was dead (commented-out) code here;
    # images stay as path strings and are expected to be opened downstream.
    print(f"✅ Final dataset size: {len(valid_dataset)} medical QA samples")
    return valid_dataset
157
+
158
+
159
if __name__ == "__main__":
    # Smoke-test the loader against the default training split.
    dataset = jsonl_to_dataset_hf_parallel("data/train.jsonl")
    if len(dataset) > 0:
        first = dataset[0]
        print("Sample:", first.keys())
        print("Question:", first["question"])
        print("Answer:", first["gt_answer"])