SunDou committed on
Commit
fe8e53f
·
verified ·
1 Parent(s): 30c724d

Upload data3/enrich_programming_problems.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data3/enrich_programming_problems.py +171 -0
data3/enrich_programming_problems.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Memory-efficient script to enrich programming_problems.jsonl
4
+ Only loads the exact rows we need from enhanced_dataset.csv
5
+ """
6
+
7
+ import json
8
+ import csv
9
+ from tqdm import tqdm
10
+ import sys
11
+
12
def get_needed_original_indices(function_csv, input_jsonl):
    """
    Determine which original_index values we actually need to look up.

    Args:
        function_csv: Path to function_dataset_v2.csv; must contain an
            'original_index' column. Rows are numbered from 1 to match the
            JSONL 'row_number' field — TODO confirm this 1-based convention
            against how the JSONL was produced.
        input_jsonl: Path to programming_problems.jsonl; each line is a JSON
            object with a 'row_number' field.

    Returns:
        Tuple (row_to_original, needed_indices):
            row_to_original: dict mapping CSV row_number -> original_index.
            needed_indices: dict mapping original_index -> list of
                row_numbers that need it.
    """
    print("Step 1: Determining which original_index values we need...")

    # First, get row_number -> original_index mapping from function_dataset_v2.
    row_to_original = {}
    with open(function_csv, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(tqdm(reader, desc="Reading function_dataset_v2"), start=1):
            try:
                row_to_original[i] = int(row['original_index'])
            except (ValueError, KeyError):
                # Skip rows whose original_index is missing or non-numeric.
                pass

    # Next, collect the row_numbers from the JSONL that we need to enrich.
    needed_indices = {}
    with open(input_jsonl, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Reading JSONL"):
            data = json.loads(line)
            row_number = data.get('row_number')

            if row_number in row_to_original:
                original_index = row_to_original[row_number]
                needed_indices.setdefault(original_index, []).append(row_number)

    print(f"Need to look up {len(needed_indices)} unique original_index values")
    # Guard: max()/min() on an empty collection raises ValueError, which
    # would crash the script when nothing in the JSONL matches the CSV.
    if needed_indices:
        print(f"Max index needed: {max(needed_indices.keys())}")
        print(f"Min index needed: {min(needed_indices.keys())}")

    return row_to_original, needed_indices
50
+
51
+
52
def load_needed_metadata(enhanced_csv, needed_indices):
    """
    Load only the needed rows from enhanced_dataset.csv.

    Streams the CSV once and keeps just the rows whose index appears in
    needed_indices, exiting early once everything has been found.

    Args:
        enhanced_csv: Path to enhanced_dataset.csv.
        needed_indices: Dict whose keys are the original_index values we
            need (as produced by get_needed_original_indices).

    Returns:
        Dictionary mapping original_index to {repo_name, path, language}.
    """
    print("\nStep 2: Loading only needed rows from enhanced_dataset.csv...")
    print(f"Looking for {len(needed_indices)} unique indices...")
    print("This will scan the entire file - may take several minutes...")

    mapping = {}
    # Track what is still missing so membership tests are O(1) and we can
    # stop scanning as soon as the set drains.
    needed_remaining = set(needed_indices.keys())

    with open(enhanced_csv, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)

        for i, row in enumerate(tqdm(reader, desc="Reading enhanced_dataset")):
            # The index column name varies (pandas-style unnamed columns);
            # try the known candidates in order.
            idx = row.get('', row.get('Unnamed: 0.1', row.get('Unnamed: 0')))
            if idx:
                try:
                    idx = int(idx)
                except ValueError:
                    # Non-numeric index cell — skip this row.
                    continue
                if idx in needed_remaining:
                    mapping[idx] = {
                        'repo_name': row.get('repo_name', ''),
                        'path': row.get('path', ''),
                        'language': row.get('language', '')
                    }
                    needed_remaining.remove(idx)

                    # Progress update every 1000 found
                    if len(mapping) % 1000 == 0:
                        print(f"Found {len(mapping)}/{len(needed_indices)} needed indices...")

                    # Early exit if we found everything
                    if not needed_remaining:
                        print(f"Found all needed indices at row {i}!")
                        break

    print(f"Loaded metadata for {len(mapping)} indices")
    print(f"Missing: {len(needed_indices) - len(mapping)} indices")

    if needed_remaining:
        print(f"Example missing indices: {list(needed_remaining)[:10]}")

    return mapping
105
+
106
+
107
def enrich_programming_problems(input_jsonl, output_jsonl, metadata_mapping, row_to_original):
    """
    Enrich programming_problems.jsonl with repository metadata.

    Every input record is written to the output (enriched or not), so the
    output file has the same number of lines as the input.

    Args:
        input_jsonl: Path to the source JSONL file.
        output_jsonl: Path to write the enriched JSONL file.
        metadata_mapping: dict original_index -> {repo_name, path, language}.
        row_to_original: dict row_number -> original_index.

    Returns:
        Tuple (matched_count, unmatched_count).
    """
    print("\nStep 3: Enriching JSONL file...")

    matched_count = 0
    unmatched_count = 0

    with open(input_jsonl, 'r', encoding='utf-8') as f_in, \
            open(output_jsonl, 'w', encoding='utf-8') as f_out:

        for line in tqdm(f_in, desc="Processing JSONL"):
            data = json.loads(line)
            # .get avoids a KeyError when a record lacks 'row_number'.
            original_index = row_to_original.get(data.get('row_number'))

            if original_index is not None and original_index in metadata_mapping:
                enrichment = metadata_mapping[original_index]
                # setdefault guards against records without a 'metadata'
                # key, which previously raised KeyError mid-run.
                metadata = data.setdefault('metadata', {})
                metadata['repo_name'] = enrichment['repo_name']
                metadata['path'] = enrichment['path']
                metadata['language'] = enrichment['language']
                matched_count += 1
            else:
                unmatched_count += 1

            f_out.write(json.dumps(data, ensure_ascii=False) + '\n')

    return matched_count, unmatched_count
140
+
141
+
142
def main():
    """
    Run the three-step enrichment pipeline and print match statistics.

    Returns:
        0 on success (used as the process exit code).
    """
    enhanced_csv = 'enhanced_dataset.csv'
    function_csv = 'function_dataset_v2.csv'
    input_jsonl = 'programming_problems.jsonl'
    output_jsonl = 'programming_problems_enriched.jsonl'

    # Step 1: Determine what we need
    row_to_original, needed_indices = get_needed_original_indices(function_csv, input_jsonl)

    # Step 2: Load only what we need
    metadata_mapping = load_needed_metadata(enhanced_csv, needed_indices)

    # Step 3: Enrich the JSONL
    matched, unmatched = enrich_programming_problems(input_jsonl, output_jsonl,
                                                     metadata_mapping, row_to_original)

    total = matched + unmatched
    print(f"\n{'='*60}")
    print(f"✅ Enrichment complete!")
    print(f"{'='*60}")
    print(f"Output written to: {output_jsonl}")
    print(f"Matched: {matched}")
    print(f"Unmatched: {unmatched}")
    print(f"Total: {total}")
    # Guard: an empty input file would make the match-rate a division by zero.
    if total:
        print(f"Match rate: {matched / total * 100:.1f}%")

    return 0
168
+
169
+
170
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())