code_repo_finetuning / scripts /02_generate_data.py
tensense's picture
Upload folder using huggingface_hub
4e909c7 verified
"""
修复版训练数据生成器
核心改进:
1. 直接基于代码内容生成准确的问答对
2. 不依赖LLM生成(避免循环依赖)
3. 使用模板化方法确保数据质量
4. 优化项目概览问题,使其更具项目特色
"""
import json
import yaml
import random
from pathlib import Path
from typing import List, Dict, Any
from dataclasses import dataclass, field # <--- 修复: dataclass 位于 dataclasses 模块
import re
from collections import defaultdict
@dataclass
class TrainingSample:
"""训练样本"""
conversations: List[Dict[str, str]]
metadata: Dict[str, Any]
class FixedDataGenerator:
"""修复版数据生成器 - 基于规则和模板"""
def __init__(self, config_path: str = "../config/default_config.yaml",
analysis_path: str = "../data/repository_analysis.json"):
with open(config_path, 'r', encoding='utf-8') as f:
self.config = yaml.safe_load(f)
try:
with open(analysis_path, 'r', encoding='utf-8') as f:
self.analysis_data = json.load(f)
except FileNotFoundError:
print(f"❌ 警告: 找不到分析文件 {analysis_path}。请先运行分析器。")
self.analysis_data = {'code_elements': [], 'project_context': {}}
self.code_elements = self.analysis_data.get('code_elements', [])
self.project_context = self.analysis_data.get('project_context', {})
self.project_name = self.project_context.get('project_name', 'Laddr')
self.training_samples = []
def generate_training_data(self):
"""生成训练数据"""
print(f"Generating training data for {self.project_name}...")
# 1. 代码解释任务(基于docstring + 代码结构)
print("Generating code explanation samples...")
self._generate_code_explanation_samples()
# 2. API使用示例(基于函数签名 + docstring)
print("Generating API usage samples...")
self._generate_api_usage_samples()
# 3. 项目概览问答(基于统计和结构信息)
print("Generating project overview samples...")
self._generate_project_overview_samples()
# 4. 代码定位任务("在哪个文件中...")
print("Generating code location samples...")
self._generate_code_location_samples()
print(f"Total samples generated: {len(self.training_samples)}")
def _generate_code_explanation_samples(self):
"""生成代码解释样本 - 基于真实代码和docstring"""
# 选择有docstring的元素
candidates = [e for e in self.code_elements
if e.get('docstring') and len(e.get('code', '')) > 50]
for element in candidates[:300]: # 增加数量限制
name = element['name']
docstring = element['docstring']
filepath = element['filepath']
element_type = element['type']
code = element.get('code', '')
# 提取函数签名
signature = self._extract_signature(code, element_type)
# 问题模板
questions = [
f"请解释 {self.project_name} 中 `{name}` 的作用。",
f"{self.project_name} 的 `{name}` 是做什么的?",
f"在 {self.project_name} 项目中,`{name}` 有什么功能?",
]
question = random.choice(questions)
# 构建高质量答案(基于真实信息)
answer_parts = []
# 1. 基本信息
answer_parts.append(f"`{name}` 是 {self.project_name} 项目中的一个 {self._type_to_cn(element_type)},位于 `{filepath}`。")
# 2. 功能描述(来自docstring)
if docstring:
# 清理docstring
clean_doc = self._clean_docstring(docstring)
answer_parts.append(f"\n**功能描述**:\n{clean_doc}")
# 3. 函数签名(如果有)
if signature:
answer_parts.append(f"\n**函数签名**:\n```python\n{signature}\n```")
# 4. 参数说明(如果有)
params = element.get('parameters', [])
if params and len(params) > 0:
param_desc = "\n**参数**:\n"
for param in params[:5]: # 最多5个参数
param_name = param.get('name', 'unknown')
param_type = param.get('type', 'Any')
# 尝试从 docstring 中提取参数描述,如果没有则使用类型
param_desc_from_doc = self._extract_param_desc(docstring, param_name)
if param_desc_from_doc:
param_info = f"- `{param_name}` ({param_type}): {param_desc_from_doc}\n"
else:
param_info = f"- `{param_name}` ({param_type})\n"
param_desc += param_info
answer_parts.append(param_desc)
# 5. 返回值(如果有)
return_type = element.get('return_type')
if return_type:
answer_parts.append(f"\n**返回值**:`{return_type}`")
answer = ''.join(answer_parts)
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": question},
{"role": "assistant", "content": answer}
],
metadata={
"task_type": "code_explanation",
"element_name": name,
"filepath": filepath
}
))
def _generate_api_usage_samples(self):
"""生成API使用示例 - 基于函数签名"""
# 选择公共函数/方法
candidates = [e for e in self.code_elements
if e['type'] in ['function', 'method']
and not e['name'].startswith('_') # 排除私有方法
and e.get('parameters')]
for element in candidates[:150]: # 增加数量限制
name = element['name']
params = element.get('parameters', [])
filepath = element['filepath']
docstring = element.get('docstring', '')
question = f"如何在 {self.project_name} 中使用 `{name}` 函数?"
# 构建使用示例
answer_parts = []
answer_parts.append(f"`{name}` 位于 `{filepath}`,使用方法如下:")
# 生成示例代码
param_names = [p['name'] for p in params if p['name'] != 'self']
if param_names:
example_code = f"{name}("
param_examples = []
for p in param_names[:5]: # 最多5个参数
param_examples.append(f"{p}=...")
example_code += ", ".join(param_examples)
example_code += ")"
answer_parts.append(f"\n```python\n{example_code}\n```")
# 参数说明
if params:
answer_parts.append("\n**参数说明**:")
for param in params[:5]:
if param['name'] != 'self':
param_type = param.get('type', 'Any')
param_desc_from_doc = self._extract_param_desc(docstring, param['name'])
answer_parts.append(f"\n- `{param['name']}`: {param_type}")
if param_desc_from_doc:
answer_parts[-1] += f" - {param_desc_from_doc}" # 追加描述
# 添加docstring提示
if docstring:
clean_doc = self._clean_docstring(docstring)[:200]
if clean_doc:
answer_parts.append(f"\n\n**功能简述**:{clean_doc}...")
answer = ''.join(answer_parts)
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": question},
{"role": "assistant", "content": answer}
],
metadata={
"task_type": "api_usage",
"element_name": name
}
))
def _generate_project_overview_samples(self):
"""生成项目概览问答 - 基于统计信息"""
stats = self.analysis_data.get('statistics', {})
description = self.project_context.get('description', '')
techs = self.project_context.get('main_technologies', [])
file_type_counts = self.analysis_data.get('statistics', {}).get('file_type_counts', {})
# --- 问题1: 项目主要功能 (更具项目特色) ---
q1_list = [
f"请用一句话描述 {self.project_name} 项目的主要功能。",
f"{self.project_name} 是一个什么样的项目?",
f"简单介绍一下 {self.project_name} 项目。"
]
q1 = random.choice(q1_list)
a1_parts = [
f"{self.project_name} 是一个 Python 项目。"
]
if description:
# 修复:确保项目描述清晰
a1_parts.append(f"\n**核心目标**:\n{description}")
else:
a1_parts.append("\n**核心目标**:此项目旨在提供一个可扩展的多代理系统框架(Agent Framework),支持任务规划、工具调用、消息队列和数据库集成等功能。")
# 添加技术栈
if techs:
a1_parts.append(f"\n\n**主要技术栈**:{', '.join(techs[:5])}等。")
a1_parts.append(f"\n\n项目包含 {stats.get('total_elements', 0)} 个代码元素,主要由 {stats.get('classes', 0)} 个类和 {stats.get('functions', 0) + stats.get('methods', 0)} 个函数/方法构成。")
a1 = ''.join(a1_parts)
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": q1},
{"role": "assistant", "content": a1}
],
metadata={"task_type": "project_overview"}
))
# --- 问题2: 项目结构 ---
q2_list = [
f"{self.project_name} 的项目结构是怎样的?",
f"请列举 {self.project_name} 的核心模块。",
]
q2 = random.choice(q2_list)
a2_parts = [f"{self.project_name} 项目包含以下主要部分:\n"]
# 获取主要模块
modules = self.project_context.get('key_modules', [])
if modules:
a2_parts.append("\n**核心模块**:\n")
for mod in modules[:10]:
a2_parts.append(f"- `{mod}`\n")
else:
a2_parts.append("\n**核心模块**:\n- `core` (核心逻辑,如Agent Runtime, Tooling, Config)\n- `cli` (命令行接口)\n- `llms` (LLM后端实现)\n")
# 优化文件类型展示
if file_type_counts:
file_stats = ', '.join(f'{k.lstrip(".").upper()}: {v}' for k, v in file_type_counts.items() if k not in ['.other'])
a2_parts.append(f"\n**主要文件类型统计**:{file_stats}")
a2 = ''.join(a2_parts)
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": q2},
{"role": "assistant", "content": a2}
],
metadata={"task_type": "project_structure"}
))
# --- 问题3: 核心类/函数 ---
top_elements = sorted(self.code_elements,
key=lambda x: x.get('complexity', 0),
reverse=True)[:10]
q3 = f"{self.project_name} 中有哪些核心类和函数?"
a3_parts = [f"{self.project_name} 的核心组件包括(基于复杂度和重要性):\n"]
for elem in top_elements:
name = elem['name']
filepath = elem['filepath']
elem_type = self._type_to_cn(elem['type'])
doc = elem.get('docstring', '')
short_doc = self._clean_docstring(doc).split('\n')[0][:80].strip()
line = f"\n- `{name}` ({elem_type}):位于 `{filepath}`"
if short_doc:
line += f" - {short_doc}..."
a3_parts.append(line)
if len(top_elements) > 0:
a3 = ''.join(a3_parts)
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": q3},
{"role": "assistant", "content": a3}
],
metadata={"task_type": "core_components"}
))
def _generate_code_location_samples(self):
"""生成代码定位任务"""
# 选择不同文件中的元素
file_elements = defaultdict(list)
for elem in self.code_elements:
# 排除非核心的__init__
if elem['name'] == '__init__' and 'module' not in elem['type']:
continue
file_elements[elem['filepath']].append(elem)
for filepath, elements in list(file_elements.items())[:50]:
# 随机选择1-3个元素
selected = random.sample(elements, min(3, len(elements)))
for elem in selected:
name = elem['name']
elem_type = self._type_to_cn(elem['type'])
question = f"在 {self.project_name} 中,`{name}` {elem_type}在哪个文件里?"
# 答案优化:更简洁,减少冗余信息,模型只需学习路径
answer = f"`{name}` 位于 `{filepath}`。"
self.training_samples.append(TrainingSample(
conversations=[
{"role": "user", "content": question},
{"role": "assistant", "content": answer}
],
metadata={
"task_type": "code_location",
"element_name": name,
"filepath": filepath
}
))
def _extract_signature(self, code: str, element_type: str) -> str:
"""提取函数/类签名"""
if not code:
return ""
lines = code.strip().split('\n')
signature_lines = []
for line in lines:
line = line.strip()
if not line:
continue
signature_lines.append(line)
# 提取函数/方法定义行
if element_type in ['function', 'method'] and (line.startswith('def ') or line.startswith('async def ')):
# 兼容多行函数签名
if not line.endswith(':'):
continue
return '\n'.join(signature_lines)
# 提取类定义行
if element_type == 'class' and line.startswith('class '):
if not line.endswith(':'):
continue
return '\n'.join(signature_lines)
# 避免包含函数/方法体
if line.endswith((':')) and not line.startswith(('def ', 'class ')):
break
# 仅返回前几行,确保只包含定义
return '\n'.join(signature_lines[:5])
def _clean_docstring(self, docstring: str) -> str:
"""清理docstring"""
if not docstring:
return ""
# 移除多余空白
lines = docstring.strip().split('\n')
cleaned = []
for line in lines:
line = line.strip()
if line:
cleaned.append(line)
return ' '.join(cleaned)
def _extract_param_desc(self, docstring: str, param_name: str) -> str:
"""从 docstring 中尝试提取参数描述"""
if not docstring:
return ""
# 匹配各种格式的参数描述,例如 Args: key: The cache key.
match = re.search(rf"(?:Args|Parameters|Params):\s*(?:[\n\r]\s*-)?\s*`?{re.escape(param_name)}`?\s*[:\-]\s*(.*)", docstring, re.IGNORECASE)
if match:
desc = match.group(1).split('\n')[0].strip()
return desc if desc else "无描述"
return ""
def _type_to_cn(self, element_type: str) -> str:
"""元素类型转中文"""
mapping = {
'function': '函数',
'method': '方法',
'class': '类',
'variable': '变量',
'module': '模块'
}
return mapping.get(element_type, element_type)
def save_training_data(self):
"""保存训练数据"""
output_dir = Path(self.config['dataset']['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)
# 打乱
random.shuffle(self.training_samples)
# 分割
total = len(self.training_samples)
train_size = int(total * 0.8)
val_size = int(total * 0.1)
if total < 10: # 如果样本太少,平均分配
train_size = max(1, total // 2)
val_size = max(1, (total - train_size) // 2)
# 再次检查,确保分割不会导致索引错误
if train_size + val_size > total:
val_size = total - train_size
train_data = self.training_samples[:train_size]
val_data = self.training_samples[train_size:train_size + val_size]
test_data = self.training_samples[train_size + val_size:]
# 保存为JSONL
self._save_jsonl(train_data, output_dir / "train.jsonl")
self._save_jsonl(val_data, output_dir / "val.jsonl")
self._save_jsonl(test_data, output_dir / "test.jsonl")
# 元数据
metadata = {
'total_samples': total,
'train_samples': len(train_data),
'val_samples': len(val_data),
'test_samples': len(test_data),
'project_name': self.project_name,
'task_distribution': self._get_task_distribution()
}
with open(output_dir / "metadata.json", 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
print(f"\n✓ Training data saved:")
print(f" Train: {len(train_data)}")
print(f" Val: {len(val_data)}")
print(f" Test: {len(test_data)}")
print(f" Total: {total}")
# 显示样本示例
print(f"\n📝 Sample training example:")
if train_data:
sample = random.choice(train_data)
print(f"Q: {sample.conversations[0]['content'][:100]}...")
print(f"A: {sample.conversations[1]['content'][:150]}...")
def _save_jsonl(self, data: List[TrainingSample], filepath: Path):
"""保存为JSONL格式"""
with open(filepath, 'w', encoding='utf-8') as f:
for sample in data:
# 仅保存对话,不保存 metadata
json.dump({'conversations': sample.conversations}, f, ensure_ascii=False)
f.write('\n')
def _get_task_distribution(self) -> Dict[str, int]:
"""统计任务分布"""
dist = {}
for sample in self.training_samples:
task_type = sample.metadata.get('task_type', 'unknown')
dist[task_type] = dist.get(task_type, 0) + 1
return dist
def main():
print("="*60)
print("Fixed Training Data Generator (Project-Specific Answers Enhanced)")
print("="*60)
generator = FixedDataGenerator()
generator.generate_training_data()
generator.save_training_data()
print("\n" + "="*60)
print("✓ Data generation completed!")
print("="*60)
if __name__ == "__main__":
main()