File size: 2,936 Bytes
4e909c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
更新config.yaml中的代码仓库配置
"""
import yaml
from pathlib import Path
import sys


def _repo_name_from_url(repo_url: str) -> str:
    """Return the repository name taken from the last segment of *repo_url*.

    Only a trailing ``.git`` is removed, so dotted repository names such as
    ``socket.io`` are kept whole.
    """
    # Accept both separator styles and ignore trailing slashes so the last
    # segment is always the repository itself.
    tail = repo_url.strip().replace('\\', '/').rstrip('/').rsplit('/', 1)[-1]
    if tail.endswith('.git') and len(tail) > len('.git'):
        tail = tail[:-len('.git')]
    return tail


def update_repository_config(repo_url: str, local_path: str = None) -> bool:
    """Point the ``repository`` section of the default config at a new repo.

    The file edited is ``config/default_config.yaml`` under the parent of the
    directory that contains this script. Before it is rewritten, its previous
    contents are saved verbatim to ``default_config.backup.yaml`` next to it.

    Args:
        repo_url: URL (or path) of the repository; its last segment, minus a
            trailing ``.git``, is used as the repository name.
        local_path: Value stored as ``repository.local_path``. When ``None``
            it defaults to ``./repos/<repository name>``.

    Returns:
        ``True`` once the config has been rewritten; ``False`` when the config
        file is missing or its top level is not a mapping (nothing is written
        in either case).
    """
    config_file = Path(__file__).parent.parent / "config" / "default_config.yaml"

    if not config_file.exists():
        print(f"❌ 错误: 找不到 {config_file}")
        return False

    # Keep the raw text: it is what gets written to the backup, so the backup
    # is an exact copy of the file as it was before this update.
    with open(config_file, 'r', encoding='utf-8') as f:
        original_text = f.read()

    # An empty YAML document loads as None; treat it as an empty mapping.
    config = yaml.safe_load(original_text) or {}
    if not isinstance(config, dict):
        print(f"❌ 错误: 配置文件格式无效: {config_file}")
        return False

    repo_name = _repo_name_from_url(repo_url)

    if local_path is None:
        local_path = f"./repos/{repo_name}"

    # Create the section when it is absent (or present but empty/null).
    repository = config.get('repository')
    if not isinstance(repository, dict):
        repository = {}
        config['repository'] = repository
    repository['url'] = repo_url
    repository['local_path'] = local_path

    # Back up the ORIGINAL contents before the config file is overwritten.
    backup_file = config_file.parent / (config_file.stem + '.backup.yaml')
    with open(backup_file, 'w', encoding='utf-8') as f:
        f.write(original_text)

    print(f"✓ 原配置已备份至: {backup_file}")

    # Save the updated configuration.
    with open(config_file, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)

    print(f"✓ 配置已更新:")
    print(f"  仓库名称: {repo_name}")
    print(f"  仓库URL: {repo_url}")
    print(f"  本地路径: {local_path}")

    return True


def main():
    """Collect the repository URL and local path, then update the config.

    With command-line arguments, ``sys.argv[1]`` is the repository URL and the
    optional ``sys.argv[2]`` is the local path. Without arguments both values
    are prompted for interactively; an empty URL cancels the run. On success
    the follow-up commands are printed.
    """
    print("="*70)
    print("代码仓库配置更新工具")
    print("="*70)
    print()

    if len(sys.argv) > 1:
        repo_url = sys.argv[1]
        # An empty second argument means "use the default path", same as
        # omitting it.
        local_path = (sys.argv[2] or None) if len(sys.argv) > 2 else None
    else:
        print("请输入新的代码仓库URL:")
        print("示例: https://github.com/gamosoft/NoteDiscovery.git")
        print()
        repo_url = input("仓库URL: ").strip()

        if not repo_url:
            print("❌ 未输入URL,已取消")
            return

        print()
        use_default_path = input("使用默认本地路径? (y/n, 默认y): ").strip().lower()

        if use_default_path == 'n':
            # A blank answer falls back to the default path instead of
            # storing an empty string as the local path.
            local_path = input("本地路径: ").strip() or None
        else:
            local_path = None

    print()
    success = update_repository_config(repo_url, local_path)

    if success:
        print()
        print("="*70)
        print("✅ 配置更新成功!")
        print()
        print("下一步:")
        print("  1. 运行知识检测: python test_base_model_knowledge.py")
        print("  2. 如果检测通过,开始训练:")
        print("     python 1_repository_analyzer.py")
        print("     python 2_data_generator.py")
        print("     deepspeed --num_gpus=2 3_model_finetuner_v4_OOM_FIX.py")
        print("="*70)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()