|
|
""" |
|
|
修复版训练数据生成器 |
|
|
核心改进: |
|
|
1. 直接基于代码内容生成准确的问答对 |
|
|
2. 不依赖LLM生成(避免循环依赖) |
|
|
3. 使用模板化方法确保数据质量 |
|
|
4. 优化项目概览问题,使其更具项目特色 |
|
|
""" |
|
|
import json |
|
|
import yaml |
|
|
import random |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any |
|
|
from dataclasses import dataclass, field |
|
|
import re |
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
@dataclass
class TrainingSample:
    """A single training sample: one chat-style conversation plus metadata."""

    # Ordered chat turns, each a {"role": "user"|"assistant", "content": str} dict.
    conversations: List[Dict[str, str]]
    # Free-form sample info (e.g. task_type, element_name, filepath); written to
    # metadata.json statistics but deliberately NOT serialized into the JSONL files.
    metadata: Dict[str, Any]
|
|
|
|
|
|
|
|
class FixedDataGenerator:
    """Rule/template-based training-data generator.

    Builds Chinese Q/A training samples directly from a repository-analysis
    JSON (code elements + project context) using fixed templates, so no LLM
    is involved and the output quality is deterministic per input.
    """

    def __init__(self, config_path: str = "../config/default_config.yaml",
                 analysis_path: str = "../data/repository_analysis.json"):
        """Load the YAML config and the repository analysis JSON.

        A missing analysis file is tolerated (warning + empty fallback) so the
        generator still runs end-to-end; a missing config file still raises,
        since there is no sensible default output directory without it.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        try:
            with open(analysis_path, 'r', encoding='utf-8') as f:
                self.analysis_data = json.load(f)
        except FileNotFoundError:
            print(f"❌ 警告: 找不到分析文件 {analysis_path}。请先运行分析器。")
            self.analysis_data = {'code_elements': [], 'project_context': {}}

        self.code_elements = self.analysis_data.get('code_elements', [])
        self.project_context = self.analysis_data.get('project_context', {})
        self.project_name = self.project_context.get('project_name', 'Laddr')

        # Accumulates TrainingSample objects across all _generate_* passes.
        self.training_samples = []

    def generate_training_data(self):
        """Run every sample generator in sequence, appending to training_samples."""
        print(f"Generating training data for {self.project_name}...")

        print("Generating code explanation samples...")
        self._generate_code_explanation_samples()

        print("Generating API usage samples...")
        self._generate_api_usage_samples()

        print("Generating project overview samples...")
        self._generate_project_overview_samples()

        print("Generating code location samples...")
        self._generate_code_location_samples()

        print(f"Total samples generated: {len(self.training_samples)}")

    def _generate_code_explanation_samples(self):
        """Generate "explain this element" samples from real code + docstrings."""
        # Only elements with a docstring and a non-trivial code body are usable.
        candidates = [e for e in self.code_elements
                      if e.get('docstring') and len(e.get('code', '')) > 50]

        # Cap at 300 to keep this task from dominating the dataset.
        for element in candidates[:300]:
            name = element['name']
            docstring = element['docstring']
            filepath = element['filepath']
            element_type = element['type']
            code = element.get('code', '')

            signature = self._extract_signature(code, element_type)

            # Several phrasings of the same question, for variety.
            questions = [
                f"请解释 {self.project_name} 中 `{name}` 的作用。",
                f"{self.project_name} 的 `{name}` 是做什么的?",
                f"在 {self.project_name} 项目中,`{name}` 有什么功能?",
            ]
            question = random.choice(questions)

            answer_parts = []

            # Opening sentence: what it is and where it lives.
            answer_parts.append(f"`{name}` 是 {self.project_name} 项目中的一个 {self._type_to_cn(element_type)},位于 `{filepath}`。")

            if docstring:
                clean_doc = self._clean_docstring(docstring)
                answer_parts.append(f"\n**功能描述**:\n{clean_doc}")

            if signature:
                answer_parts.append(f"\n**函数签名**:\n```python\n{signature}\n```")

            # Parameter section (at most 5 parameters to keep answers short).
            params = element.get('parameters', [])
            if params:  # was `params and len(params) > 0` — redundant length check
                param_desc = "\n**参数**:\n"
                for param in params[:5]:
                    param_name = param.get('name', 'unknown')
                    param_type = param.get('type', 'Any')

                    # Prefer the description found in the docstring, if any.
                    param_desc_from_doc = self._extract_param_desc(docstring, param_name)
                    if param_desc_from_doc:
                        param_info = f"- `{param_name}` ({param_type}): {param_desc_from_doc}\n"
                    else:
                        param_info = f"- `{param_name}` ({param_type})\n"

                    param_desc += param_info
                answer_parts.append(param_desc)

            return_type = element.get('return_type')
            if return_type:
                answer_parts.append(f"\n**返回值**:`{return_type}`")

            answer = ''.join(answer_parts)

            self.training_samples.append(TrainingSample(
                conversations=[
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer}
                ],
                metadata={
                    "task_type": "code_explanation",
                    "element_name": name,
                    "filepath": filepath
                }
            ))

    def _generate_api_usage_samples(self):
        """Generate "how do I call this" samples from public function signatures."""
        # Public (non-underscore) functions/methods that declare parameters.
        candidates = [e for e in self.code_elements
                      if e['type'] in ['function', 'method']
                      and not e['name'].startswith('_')
                      and e.get('parameters')]

        for element in candidates[:150]:
            name = element['name']
            params = element.get('parameters', [])
            filepath = element['filepath']
            docstring = element.get('docstring', '')

            question = f"如何在 {self.project_name} 中使用 `{name}` 函数?"

            answer_parts = []
            answer_parts.append(f"`{name}` 位于 `{filepath}`,使用方法如下:")

            # Build a skeleton call like `name(a=..., b=...)`, skipping `self`.
            param_names = [p['name'] for p in params if p['name'] != 'self']
            if param_names:
                example_code = f"{name}("
                param_examples = []
                for p in param_names[:5]:
                    param_examples.append(f"{p}=...")
                example_code += ", ".join(param_examples)
                example_code += ")"

                answer_parts.append(f"\n```python\n{example_code}\n```")

            # Per-parameter explanation (again skipping `self`).
            if params:
                answer_parts.append("\n**参数说明**:")
                for param in params[:5]:
                    if param['name'] != 'self':
                        param_type = param.get('type', 'Any')

                        param_desc_from_doc = self._extract_param_desc(docstring, param['name'])

                        answer_parts.append(f"\n- `{param['name']}`: {param_type}")
                        if param_desc_from_doc:
                            # Append the doc description to the line just added.
                            answer_parts[-1] += f" - {param_desc_from_doc}"

            if docstring:
                clean_doc = self._clean_docstring(docstring)[:200]
                if clean_doc:
                    answer_parts.append(f"\n\n**功能简述**:{clean_doc}...")

            answer = ''.join(answer_parts)

            self.training_samples.append(TrainingSample(
                conversations=[
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer}
                ],
                metadata={
                    "task_type": "api_usage",
                    "element_name": name
                }
            ))

    def _generate_project_overview_samples(self):
        """Generate project-overview Q/A from aggregate statistics and context."""
        stats = self.analysis_data.get('statistics', {})
        description = self.project_context.get('description', '')
        techs = self.project_context.get('main_technologies', [])
        # Reuse `stats` instead of re-reading analysis_data['statistics'].
        file_type_counts = stats.get('file_type_counts', {})

        # --- Sample 1: "what is this project?" ---
        q1_list = [
            f"请用一句话描述 {self.project_name} 项目的主要功能。",
            f"{self.project_name} 是一个什么样的项目?",
            f"简单介绍一下 {self.project_name} 项目。"
        ]
        q1 = random.choice(q1_list)

        a1_parts = [
            f"{self.project_name} 是一个 Python 项目。"
        ]

        if description:
            a1_parts.append(f"\n**核心目标**:\n{description}")
        else:
            # Hard-coded fallback description used when analysis has none.
            a1_parts.append("\n**核心目标**:此项目旨在提供一个可扩展的多代理系统框架(Agent Framework),支持任务规划、工具调用、消息队列和数据库集成等功能。")

        if techs:
            a1_parts.append(f"\n\n**主要技术栈**:{', '.join(techs[:5])}等。")

        a1_parts.append(f"\n\n项目包含 {stats.get('total_elements', 0)} 个代码元素,主要由 {stats.get('classes', 0)} 个类和 {stats.get('functions', 0) + stats.get('methods', 0)} 个函数/方法构成。")

        a1 = ''.join(a1_parts)

        self.training_samples.append(TrainingSample(
            conversations=[
                {"role": "user", "content": q1},
                {"role": "assistant", "content": a1}
            ],
            metadata={"task_type": "project_overview"}
        ))

        # --- Sample 2: project structure / key modules ---
        q2_list = [
            f"{self.project_name} 的项目结构是怎样的?",
            f"请列举 {self.project_name} 的核心模块。",
        ]
        q2 = random.choice(q2_list)

        a2_parts = [f"{self.project_name} 项目包含以下主要部分:\n"]

        modules = self.project_context.get('key_modules', [])
        if modules:
            a2_parts.append("\n**核心模块**:\n")
            for mod in modules[:10]:
                a2_parts.append(f"- `{mod}`\n")
        else:
            # Fallback module list when the analysis provides none.
            a2_parts.append("\n**核心模块**:\n- `core` (核心逻辑,如Agent Runtime, Tooling, Config)\n- `cli` (命令行接口)\n- `llms` (LLM后端实现)\n")

        if file_type_counts:
            # '.py' -> 'PY: n'; the catch-all '.other' bucket is excluded.
            file_stats = ', '.join(f'{k.lstrip(".").upper()}: {v}' for k, v in file_type_counts.items() if k != '.other')
            a2_parts.append(f"\n**主要文件类型统计**:{file_stats}")

        a2 = ''.join(a2_parts)

        self.training_samples.append(TrainingSample(
            conversations=[
                {"role": "user", "content": q2},
                {"role": "assistant", "content": a2}
            ],
            metadata={"task_type": "project_structure"}
        ))

        # --- Sample 3: most important components, ranked by complexity ---
        top_elements = sorted(self.code_elements,
                              key=lambda x: x.get('complexity', 0),
                              reverse=True)[:10]

        q3 = f"{self.project_name} 中有哪些核心类和函数?"
        a3_parts = [f"{self.project_name} 的核心组件包括(基于复杂度和重要性):\n"]

        for elem in top_elements:
            name = elem['name']
            filepath = elem['filepath']
            elem_type = self._type_to_cn(elem['type'])

            # First line of the cleaned docstring, truncated to 80 chars.
            doc = elem.get('docstring', '')
            short_doc = self._clean_docstring(doc).split('\n')[0][:80].strip()

            line = f"\n- `{name}` ({elem_type}):位于 `{filepath}`"
            if short_doc:
                line += f" - {short_doc}..."
            a3_parts.append(line)

        if top_elements:
            a3 = ''.join(a3_parts)
            self.training_samples.append(TrainingSample(
                conversations=[
                    {"role": "user", "content": q3},
                    {"role": "assistant", "content": a3}
                ],
                metadata={"task_type": "core_components"}
            ))

    def _generate_code_location_samples(self):
        """Generate "which file is X in" samples, a few per file."""
        # Group elements by their source file.
        file_elements = defaultdict(list)
        for elem in self.code_elements:
            # Skip constructors: `__init__` is ambiguous across classes unless
            # the element is module-level.
            if elem['name'] == '__init__' and 'module' not in elem['type']:
                continue
            file_elements[elem['filepath']].append(elem)

        for filepath, elements in list(file_elements.items())[:50]:
            # At most 3 random elements per file.
            selected = random.sample(elements, min(3, len(elements)))

            for elem in selected:
                name = elem['name']
                elem_type = self._type_to_cn(elem['type'])

                question = f"在 {self.project_name} 中,`{name}` {elem_type}在哪个文件里?"

                answer = f"`{name}` 位于 `{filepath}`。"

                self.training_samples.append(TrainingSample(
                    conversations=[
                        {"role": "user", "content": question},
                        {"role": "assistant", "content": answer}
                    ],
                    metadata={
                        "task_type": "code_location",
                        "element_name": name,
                        "filepath": filepath
                    }
                ))

    def _extract_signature(self, code: str, element_type: str) -> str:
        """Extract the `def`/`class` signature lines from a code snippet.

        Collects non-blank lines until the signature's closing ':' is seen
        (handles multi-line signatures); falls back to the first 5 collected
        lines if no clear signature end is found.
        """
        if not code:
            return ""

        lines = code.strip().split('\n')
        signature_lines = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            signature_lines.append(line)

            # Function/method: return once the def line (or its continuation)
            # terminates with ':'.
            if element_type in ['function', 'method'] and (line.startswith('def ') or line.startswith('async def ')):
                if not line.endswith(':'):
                    continue
                return '\n'.join(signature_lines)

            # Class: same idea for the `class ...:` header.
            if element_type == 'class' and line.startswith('class '):
                if not line.endswith(':'):
                    continue
                return '\n'.join(signature_lines)

            # Any other ':'-terminated line ends a multi-line signature.
            if line.endswith(':') and not line.startswith(('def ', 'class ')):
                break

        # Fallback: whatever was gathered, truncated to 5 lines.
        return '\n'.join(signature_lines[:5])

    def _clean_docstring(self, docstring: str) -> str:
        """Collapse a docstring to a single line of space-joined text."""
        if not docstring:
            return ""

        lines = docstring.strip().split('\n')
        cleaned = []
        for line in lines:
            line = line.strip()
            if line:
                cleaned.append(line)

        return ' '.join(cleaned)

    def _extract_param_desc(self, docstring: str, param_name: str) -> str:
        """Best-effort extraction of one parameter's description from a docstring.

        Looks for Google-style `Args:`/`Parameters:` sections; returns "" when
        the docstring is empty or the parameter is not documented.
        """
        if not docstring:
            return ""

        match = re.search(rf"(?:Args|Parameters|Params):\s*(?:[\n\r]\s*-)?\s*`?{re.escape(param_name)}`?\s*[:\-]\s*(.*)", docstring, re.IGNORECASE)
        if match:
            desc = match.group(1).split('\n')[0].strip()
            return desc if desc else "无描述"
        return ""

    def _type_to_cn(self, element_type: str) -> str:
        """Map an element-type keyword to its Chinese label (pass through unknowns)."""
        mapping = {
            'function': '函数',
            'method': '方法',
            'class': '类',
            'variable': '变量',
            'module': '模块'
        }
        return mapping.get(element_type, element_type)

    def save_training_data(self):
        """Shuffle, split 80/10/10, and write train/val/test JSONL + metadata."""
        output_dir = Path(self.config['dataset']['output_dir'])
        output_dir.mkdir(parents=True, exist_ok=True)

        total = len(self.training_samples)
        # Guard: with zero samples the small-dataset fallback below would
        # compute a negative val_size; bail out early instead.
        if total == 0:
            print("❌ No training samples to save. Run generate_training_data() first.")
            return

        # Shuffle so each split gets a mix of task types.
        random.shuffle(self.training_samples)

        train_size = int(total * 0.8)
        val_size = int(total * 0.1)

        # Tiny datasets: guarantee at least one train and one val sample.
        if total < 10:
            train_size = max(1, total // 2)
            val_size = max(1, (total - train_size) // 2)

        # Never let train+val exceed the available samples.
        if train_size + val_size > total:
            val_size = total - train_size

        train_data = self.training_samples[:train_size]
        val_data = self.training_samples[train_size:train_size + val_size]
        test_data = self.training_samples[train_size + val_size:]

        self._save_jsonl(train_data, output_dir / "train.jsonl")
        self._save_jsonl(val_data, output_dir / "val.jsonl")
        self._save_jsonl(test_data, output_dir / "test.jsonl")

        metadata = {
            'total_samples': total,
            'train_samples': len(train_data),
            'val_samples': len(val_data),
            'test_samples': len(test_data),
            'project_name': self.project_name,
            'task_distribution': self._get_task_distribution()
        }

        with open(output_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        print(f"\n✓ Training data saved:")
        print(f"  Train: {len(train_data)}")
        print(f"  Val: {len(val_data)}")
        print(f"  Test: {len(test_data)}")
        print(f"  Total: {total}")

        # Show one truncated example so the operator can eyeball quality.
        print(f"\n📝 Sample training example:")
        if train_data:
            sample = random.choice(train_data)
            print(f"Q: {sample.conversations[0]['content'][:100]}...")
            print(f"A: {sample.conversations[1]['content'][:150]}...")

    def _save_jsonl(self, data: List[TrainingSample], filepath: Path):
        """Write samples as JSONL; metadata is intentionally omitted from the files."""
        with open(filepath, 'w', encoding='utf-8') as f:
            for sample in data:
                json.dump({'conversations': sample.conversations}, f, ensure_ascii=False)
                f.write('\n')

    def _get_task_distribution(self) -> Dict[str, int]:
        """Count generated samples per task_type for the metadata report."""
        dist = {}
        for sample in self.training_samples:
            task_type = sample.metadata.get('task_type', 'unknown')
            dist[task_type] = dist.get(task_type, 0) + 1
        return dist
|
|
|
|
|
|
|
|
def main():
    """Entry point: build the generator, produce the samples, write them out."""
    banner = "=" * 60
    print(banner)
    print("Fixed Training Data Generator (Project-Specific Answers Enhanced)")
    print(banner)

    gen = FixedDataGenerator()
    gen.generate_training_data()
    gen.save_training_data()

    print("\n" + banner)
    print("✓ Data generation completed!")
    print(banner)
|
|
|
|
|
|
|
|
# Run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|