# dataset-builder / data3 / generate_programming_problems.py
# Uploaded by SunDou with huggingface_hub (commit 315961d, verified).
#!/usr/bin/env python3
"""
Generate programming problems from function_dataset_v2.csv using Gemini API.
Filters by relevance score and controls API cost.
"""
import csv
import json
import os
import sys
import vertexai
from vertexai.generative_models import GenerativeModel
from datetime import datetime
from typing import Dict, Optional, Tuple, List
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
# Configuration
PROJECT_ID = "tangou"  # Google Cloud project passed to vertexai.init
MODEL_NAME = "gemini-2.5-flash-lite"  # Flash-Lite tier chosen for cost efficiency
MIN_RELEVANCE_SCORE = 1  # Default minimum relevance score; rows scoring lower are skipped
MAX_BUDGET_USD = 50.0  # Default spending cap in USD for a single run
# Per-million-token prices in USD used to estimate cost and enforce the budget.
# NOTE(review): these are assumed to be the Gemini 2.5 Flash-Lite text
# input/output list prices for MODEL_NAME above; confirm on the pricing page
# before large runs, since prices change.
# https://cloud.google.com/vertex-ai/generative-ai/pricing
INPUT_PRICE_PER_MILLION = 0.1
OUTPUT_PRICE_PER_MILLION = 0.4
# If using Gemini 1.5 Flash instead:
# INPUT_PRICE_PER_MILLION = 0.075
# OUTPUT_PRICE_PER_MILLION = 0.30
# Prompt sent to the model for each selected dataset row.
# The only placeholder is {code}. It is filled via str.format() with the row's
# function_content. Any literal braces added to this template must be doubled
# ({{ and }}), or .format() will treat them as placeholders.
# NOTE(review): the snippet fence is always labeled python, even though rows
# carry a 'language' column; confirm that all inputs are Python.
PROMPT_TEMPLATE = """You are an expert in scientific computing and computational chemistry/biology/physics. Please create a high-quality programming problem inspired by the following code snippet from a real scientific computing project.
The problem should focus on scientific computing concepts such as:
- Numerical algorithms and simulations
- Data analysis and visualization
- Mathematical modeling
- Scientific data processing
- Computational methods in chemistry, biology, or physics
Code snippet for inspiration:
```python
{code}
```
Present your output in two distinct sections:
[Problem Description]
Create a **completely self-contained** problem description that:
- Does NOT directly reference the code snippet above
- Provides all necessary context and background
- Clearly states what needs to be implemented
- Specifies input/output format and constraints
- Is inspired by the scientific computing concepts in the code but creates a NEW, interesting problem
- Assumes common programming knowledge but explains any domain-specific concepts
[Solution]
Provide a comprehensive, **correct** Python solution that:
- Accurately solves the problem described
- Includes clear comments explaining the approach
- Uses appropriate scientific computing libraries (numpy, scipy, etc.) when relevant
- Is complete and runnable
- Follows best practices for scientific computing
Remember: The problem should be INSPIRED by the code, not a direct copy. Create something educational and interesting for scientific computing practitioners."""
class GeminiAPIClient:
    """Vertex AI Gemini client that accumulates token usage and estimated cost.

    Cost is estimated from each response's ``usage_metadata`` using the
    module-level ``INPUT_PRICE_PER_MILLION`` and ``OUTPUT_PRICE_PER_MILLION``
    rates. The running totals are updated and read under a lock, so one
    instance can be shared by the worker threads of a thread pool.
    """

    def __init__(self, project_id: str, model_name: str):
        """Initialize Vertex AI and the Gemini model.

        Args:
            project_id: Google Cloud project ID passed to ``vertexai.init``.
            model_name: Name of the Gemini model to use.
        """
        vertexai.init(project=project_id)
        self.model = GenerativeModel(model_name)
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_requests = 0
        self.total_cost = 0.0
        # Protects the running totals above from concurrent updates.
        self._lock = threading.Lock()

    def generate_content(self, prompt: str) -> Tuple[str, Dict]:
        """Send *prompt* to the model and record its token usage and cost.

        Args:
            prompt: The prompt to send to the API.

        Returns:
            Tuple of ``(response_text, usage_info)``. ``usage_info`` has the
            keys ``input_tokens``, ``output_tokens`` (billed output tokens,
            including any thinking tokens), ``thinking_tokens``,
            ``total_tokens``, ``input_cost``, ``output_cost`` and
            ``request_cost``.

        Raises:
            Exception: Any error from the API call, or from reading
                ``response.text``, is printed and re-raised. Usage is
                recorded before ``response.text`` is read, so a response
                without text still counts toward the totals.
        """
        try:
            response = self.model.generate_content(prompt)
            usage_metadata = response.usage_metadata
            input_tokens = usage_metadata.prompt_token_count or 0
            # Thinking tokens are billed at the output rate. The field is
            # missing on older SDK versions, hence the getattr default.
            thinking_tokens = getattr(usage_metadata, 'thoughts_token_count', 0) or 0
            output_tokens = (usage_metadata.candidates_token_count or 0) + thinking_tokens

            input_cost = (input_tokens / 1_000_000) * INPUT_PRICE_PER_MILLION
            output_cost = (output_tokens / 1_000_000) * OUTPUT_PRICE_PER_MILLION
            request_cost = input_cost + output_cost

            with self._lock:
                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.total_requests += 1
                self.total_cost += request_cost

            usage_info = {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'thinking_tokens': thinking_tokens,
                'total_tokens': input_tokens + output_tokens,
                'input_cost': input_cost,
                'output_cost': output_cost,
                'request_cost': request_cost
            }
            return response.text, usage_info
        except Exception as e:
            print(f"Error generating content: {e}")
            raise

    def get_total_usage(self) -> Dict:
        """Return a consistent snapshot of the accumulated usage.

        Returns:
            Dictionary with ``total_requests``, ``total_input_tokens``,
            ``total_output_tokens``, ``total_tokens`` and ``total_cost``.
        """
        # Read under the lock so the counters come from the same moment even
        # while worker threads are still recording requests.
        with self._lock:
            requests = self.total_requests
            input_tokens = self.total_input_tokens
            output_tokens = self.total_output_tokens
            cost = self.total_cost
        return {
            'total_requests': requests,
            'total_input_tokens': input_tokens,
            'total_output_tokens': output_tokens,
            'total_tokens': input_tokens + output_tokens,
            'total_cost': cost
        }

    def print_usage_summary(self, max_budget: Optional[float] = None):
        """Print a summary of API usage and costs.

        Args:
            max_budget: Budget in USD used for the "Budget Remaining" line.
                Defaults to the module-level ``MAX_BUDGET_USD`` when omitted.
        """
        budget = MAX_BUDGET_USD if max_budget is None else max_budget
        usage = self.get_total_usage()
        print("\n" + "="*70)
        print("API USAGE SUMMARY")
        print("="*70)
        print(f"Total Requests: {usage['total_requests']}")
        print(f"Total Input Tokens: {usage['total_input_tokens']:,}")
        print(f"Total Output Tokens: {usage['total_output_tokens']:,}")
        print(f"Total Tokens: {usage['total_tokens']:,}")
        print(f"\nTotal Cost: ${usage['total_cost']:.6f}")
        print(f"Budget Remaining: ${budget - usage['total_cost']:.6f}")
        print("="*70)
def _load_processed_rows(output_file: str) -> set:
    """Collect the ``row_number`` values already present in a JSONL output file.

    Blank or undecodable lines (e.g. a line cut short by an interrupted run)
    are skipped. If the file cannot be read, a warning is printed and whatever
    was collected so far is returned. A missing file yields an empty set.
    """
    processed_rows = set()
    if not os.path.exists(output_file):
        print(f"No existing output file found. Will create new file.")
        return processed_rows
    print(f"Checking existing output file for already processed rows...")
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # Only JSON objects can carry a row_number; anything else is ignored.
                if isinstance(data, dict) and 'row_number' in data:
                    processed_rows.add(data['row_number'])
        print(f"Found {len(processed_rows)} already processed rows. These will be skipped.")
    except Exception as e:
        print(f"Warning: Could not read existing output file: {e}")
    return processed_rows


def process_function_dataset(
    input_file: str,
    output_file: str,
    min_score: int = MIN_RELEVANCE_SCORE,
    max_budget: float = MAX_BUDGET_USD,
    max_samples: Optional[int] = None,
    start_from: int = 0,
    max_workers: int = 5
):
    """Generate programming problems for qualifying rows of a function dataset.

    A CSV row qualifies when it is not already recorded in *output_file*, its
    relevance score is at least *min_score*, and its function content has at
    least 50 characters. Each qualifying row is turned into a prompt and sent
    to Gemini on a thread pool. Each successful response is appended to
    *output_file* as one JSON line.

    Resuming: rows whose ``row_number`` already appears in *output_file* are
    skipped, and an existing output file is appended to rather than
    truncated. Delete the output file to start from scratch.

    Budget: the total spend is checked each time a request finishes. Once it
    reaches *max_budget*, queued requests are cancelled. Requests already in
    flight (up to *max_workers*) still complete and are saved, so the final
    spend can slightly exceed the budget.

    Args:
        input_file: Path to function_dataset_v2.csv.
        output_file: Path to the output JSONL file.
        min_score: Minimum relevance score to process.
        max_budget: Maximum budget in USD.
        max_samples: Maximum number of rows to send to the API
            (None for no limit; 0 means none).
        start_from: Skip the first N data rows.
        max_workers: Maximum number of concurrent workers (default: 5).

    Returns:
        Number of problems successfully generated and written in this run.
    """
    print(f"Starting programming problem generation...")
    print(f"Input: {input_file}")
    print(f"Output: {output_file}")
    print(f"Min Relevance Score: {min_score}")
    print(f"Max Budget: ${max_budget:.2f}")
    print(f"Max Workers: {max_workers}")
    if max_samples is not None:
        print(f"Max Samples: {max_samples}")
    print(f"Starting from row: {start_from}")
    print()

    processed_rows = _load_processed_rows(output_file)
    print()

    client = GeminiAPIClient(PROJECT_ID, MODEL_NAME)

    # Statistics
    total_rows = 0
    processed = 0
    skipped_low_score = 0
    skipped_no_code = 0
    skipped_already_processed = 0
    errors = 0

    # Build the task list. row_number is the 1-based data-row index, which is
    # also what gets written to the output and read back on resume.
    tasks = []
    # newline='' lets the csv module handle newlines embedded in quoted
    # fields (function bodies span many lines).
    with open(input_file, 'r', encoding='utf-8', newline='') as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            total_rows += 1
            if total_rows <= start_from:
                continue
            if total_rows in processed_rows:
                skipped_already_processed += 1
                continue
            if max_samples is not None and len(tasks) >= max_samples:
                break

            try:
                relevance_score = int(row.get('relevance_score', 0))
            except (ValueError, TypeError):
                relevance_score = 0
            if relevance_score < min_score:
                skipped_low_score += 1
                continue

            # DictReader fills missing trailing fields with None, so guard
            # before calling .strip().
            function_content = (row.get('function_content') or '').strip()
            if len(function_content) < 50:
                skipped_no_code += 1
                continue

            metadata = {
                'original_index': row.get('original_index'),
                'function_name': row.get('function_name'),
                'repo_name': row.get('repo_name'),
                'path': row.get('path'),
                'language': row.get('language'),
                'relevance_score': relevance_score,
                'function_start_line': row.get('function_start_line'),
                'function_end_line': row.get('function_end_line'),
            }
            tasks.append({
                'row_number': total_rows,
                'metadata': metadata,
                'prompt': PROMPT_TEMPLATE.format(code=function_content),
                'function_content': function_content
            })

    print(f"Total rows read: {total_rows}")
    print(f"Tasks to process: {len(tasks)}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"\nStarting concurrent processing with {max_workers} workers...\n")

    def process_task(task):
        """Call the API for one task and return a success/error result dict."""
        row_number = task['row_number']
        metadata = task['metadata']
        prompt = task['prompt']
        label = (f"row {row_number} (score={metadata['relevance_score']}, "
                 f"func={metadata['function_name']})")
        try:
            response_text, usage_info = client.generate_content(prompt)
        except Exception as e:
            # One print per task keeps output readable across threads.
            print(f"Processing {label}... ✗ Error: {e}")
            return {
                'success': False,
                'error': str(e),
                'row_number': row_number
            }
        print(f"Processing {label}... "
              f"✓ (${usage_info['request_cost']:.6f}, {usage_info['total_tokens']} tokens)")
        return {
            'success': True,
            'data': {
                'metadata': metadata,
                'prompt': prompt,
                'response': response_text,
                'usage': usage_info,
                'timestamp': datetime.now().isoformat(),
                'row_number': row_number
            }
        }

    # Always append when prior output exists. Its rows were skipped above, so
    # truncating the file would silently discard finished (and paid-for) work.
    mode = 'a' if (start_from > 0 or os.path.exists(output_file)) else 'w'
    budget_exhausted = False

    with open(output_file, mode, encoding='utf-8') as outfile:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_task = {executor.submit(process_task, task): task for task in tasks}
            try:
                for future in as_completed(future_to_task):
                    if future.cancelled():
                        continue
                    result = future.result()
                    # Save before checking the budget: this request is already paid for.
                    if result['success']:
                        outfile.write(json.dumps(result['data'], ensure_ascii=False) + '\n')
                        outfile.flush()
                        processed += 1
                        if processed % 10 == 0:
                            print(f"\n--- Progress: {processed} problems generated, ${client.total_cost:.6f} spent ---\n")
                    else:
                        errors += 1

                    if not budget_exhausted and client.total_cost >= max_budget:
                        budget_exhausted = True
                        print(f"\n⚠️ Budget limit reached (${client.total_cost:.6f} >= ${max_budget:.2f})")
                        print(f"Cancelling remaining tasks...")
                        # Only queued futures can be cancelled. Running ones
                        # finish and are still written by this loop.
                        for pending in future_to_task:
                            pending.cancel()
            except KeyboardInterrupt:
                # Without this, executor shutdown would wait for every queued
                # task. Cancel what has not started, then propagate.
                for pending in future_to_task:
                    pending.cancel()
                raise

    print("\n" + "="*70)
    print("PROCESSING COMPLETE")
    print("="*70)
    print(f"Total rows read: {total_rows}")
    print(f"Successfully processed: {processed}")
    print(f"Skipped (already processed): {skipped_already_processed}")
    print(f"Skipped (low score): {skipped_low_score}")
    print(f"Skipped (no/short code): {skipped_no_code}")
    print(f"Errors: {errors}")
    client.print_usage_summary()
    print(f"\nResults saved to: {output_file}")
    return processed
if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, and run the generator.
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate programming problems from function dataset using Gemini API'
    )
    parser.add_argument(
        '--input',
        default='function_dataset_v2.csv',
        help='Input CSV file (default: function_dataset_v2.csv)'
    )
    parser.add_argument(
        '--output',
        default='programming_problems.jsonl',
        help='Output JSONL file (default: programming_problems.jsonl)'
    )
    parser.add_argument(
        '--min-score',
        type=int,
        default=MIN_RELEVANCE_SCORE,
        help=f'Minimum relevance score (default: {MIN_RELEVANCE_SCORE})'
    )
    parser.add_argument(
        '--max-budget',
        type=float,
        default=MAX_BUDGET_USD,
        help=f'Maximum budget in USD (default: {MAX_BUDGET_USD})'
    )
    parser.add_argument(
        '--max-samples',
        type=int,
        default=None,
        help='Maximum number of samples to process (default: no limit)'
    )
    parser.add_argument(
        '--start-from',
        type=int,
        default=0,
        help='Skip the first N data rows (for resuming, default: 0)'
    )
    parser.add_argument(
        '--max-workers',
        type=int,
        default=10,
        help='Maximum number of concurrent workers (default: 10)'
    )
    args = parser.parse_args()

    # Reject bad values now. Otherwise they would only fail after the whole
    # CSV has been read and Vertex AI initialized.
    if args.max_workers < 1:
        parser.error('--max-workers must be at least 1')
    if args.max_samples is not None and args.max_samples < 0:
        parser.error('--max-samples must be non-negative')

    if not os.path.exists(args.input):
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    try:
        process_function_dataset(
            input_file=args.input,
            output_file=args.output,
            min_score=args.min_score,
            max_budget=args.max_budget,
            max_samples=args.max_samples,
            start_from=args.start_from,
            max_workers=args.max_workers
        )
        print("\n✅ Success!")
    except KeyboardInterrupt:
        print("\n\n⚠️ Interrupted by user. Progress has been saved to output file.")
        print(" You can resume by using --start-from <row_number>")
        sys.exit(0)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)