|
|
""" |
|
|
Update the code repository settings in config/default_config.yaml.
|
|
""" |
|
|
import yaml |
|
|
from pathlib import Path |
|
|
import sys |
|
|
|
|
|
|
|
|
def update_repository_config(repo_url: str, local_path: str = None):
    """Point the ``repository`` section of the default config at a new repo.

    Loads ``config/default_config.yaml`` (resolved from the parent of this
    file's directory), saves a byte-for-byte copy of the current file as
    ``default_config.backup.yaml`` next to it, and then rewrites the config
    with the new ``repository.url`` and ``repository.local_path`` values.

    Note that the config is re-serialized with ``yaml.dump``, so comments in
    the original file are not carried over; the backup copy keeps them.

    Args:
        repo_url: URL of the repository to record in the config.
        local_path: Checkout location to record. When omitted or empty,
            ``./repos/<repo name>`` is used, where the repo name is the last
            segment of ``repo_url`` without a trailing ``.git``.

    Returns:
        True when the config was rewritten; False when the config file is
        missing, or when its top level or its ``repository`` entry is not a
        mapping.
    """
    config_file = Path(__file__).parent.parent / "config" / "default_config.yaml"

    if not config_file.exists():
        # Report the path that was actually checked, not a generic file name.
        print(f"❌ 错误: 找不到 {config_file}")
        return False

    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # An empty YAML document loads as None; treat it as an empty mapping.
    if config is None:
        config = {}
    if not isinstance(config, dict):
        print(f"❌ 错误: 配置文件格式无效: {config_file}")
        return False

    # Create the section when it is absent or empty instead of crashing.
    repository = config.get('repository')
    if repository is None:
        repository = {}
        config['repository'] = repository
    elif not isinstance(repository, dict):
        print("❌ 错误: 配置中的 repository 不是字典")
        return False

    # Use the last URL segment and drop only a trailing ".git", so repository
    # names that contain dots (e.g. "my.repo") are kept intact.
    repo_name = repo_url.rstrip('/').rsplit('/', 1)[-1]
    if repo_name.endswith('.git'):
        repo_name = repo_name[:-len('.git')]

    if not local_path:
        local_path = f"./repos/{repo_name}"

    # Back up the file exactly as it is on disk BEFORE anything is changed,
    # so the backup really holds the previous configuration.
    backup_file = config_file.parent / (config_file.stem + '.backup.yaml')
    backup_file.write_bytes(config_file.read_bytes())
    print(f"✓ 原配置已备份至: {backup_file}")

    repository['url'] = repo_url
    repository['local_path'] = local_path

    with open(config_file, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)

    print("✓ 配置已更新:")
    print(f" 仓库名称: {repo_name}")
    print(f" 仓库URL: {repo_url}")
    print(f" 本地路径: {local_path}")

    return True
|
|
|
|
|
|
|
|
def main():
    """Command-line / interactive entry point for the config updater.

    With command-line arguments, ``sys.argv[1]`` is the repository URL and
    the optional ``sys.argv[2]`` is the local path. Without arguments, both
    values are prompted for on stdin.

    An empty URL from either source cancels the run before anything is
    written, and a blank local path selects the default location chosen by
    ``update_repository_config``. On success, the follow-up steps are
    printed.
    """
    banner = "=" * 70
    print(banner)
    print("代码仓库配置更新工具")
    print(banner)
    print()

    if len(sys.argv) > 1:
        repo_url = sys.argv[1].strip()
        local_path = sys.argv[2].strip() if len(sys.argv) > 2 else None
    else:
        print("请输入新的代码仓库URL:")
        print("示例: https://github.com/gamosoft/NoteDiscovery.git")
        print()
        repo_url = input("仓库URL: ").strip()
        local_path = None

        # Only ask about the local path once a URL has been provided.
        if repo_url:
            print()
            use_default_path = input("使用默认本地路径? (y/n, 默认y): ").strip().lower()
            if use_default_path == 'n':
                local_path = input("本地路径: ").strip()

    # Validate the URL for BOTH input paths so an empty command-line argument
    # cannot overwrite the configured URL with an empty string.
    if not repo_url:
        print("❌ 未输入URL,已取消")
        return

    # A blank custom path means "use the default".
    local_path = local_path or None

    print()
    success = update_repository_config(repo_url, local_path)

    if success:
        print()
        print(banner)
        print("✅ 配置更新成功!")
        print()
        print("下一步:")
        print(" 1. 运行知识检测: python test_base_model_knowledge.py")
        print(" 2. 如果检测通过,开始训练:")
        print(" python 1_repository_analyzer.py")
        print(" python 2_data_generator.py")
        print(" deepspeed --num_gpus=2 3_model_finetuner_v4_OOM_FIX.py")
        print(banner)
|
|
|
|
|
|
|
|
# Script entry point: run the updater only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|