File size: 2,936 Bytes
4e909c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
"""
更新config.yaml中的代码仓库配置
"""
import yaml
from pathlib import Path
import sys
def update_repository_config(repo_url: str, local_path: str = None):
    """Point the ``repository`` section of the default config at a new repo.

    Reads ``config/default_config.yaml`` (resolved relative to the parent of
    this script's directory), saves a verbatim copy of that file as
    ``default_config.backup.yaml`` next to it, then rewrites the config with
    the new ``repository.url`` and ``repository.local_path`` values.

    Args:
        repo_url: URL of the repository to record in the config.
        local_path: Local path to record. When ``None``,
            ``./repos/<repo name>`` is used, where the repo name is the last
            segment of ``repo_url`` with a trailing ``.git`` removed.

    Returns:
        ``True`` when the config file was rewritten, ``False`` when the
        config file does not exist.
    """
    config_file = Path(__file__).parent.parent / "config" / "default_config.yaml"

    if not config_file.exists():
        # Report the path that was actually checked instead of a stale name.
        print(f"❌ 错误: 找不到配置文件: {config_file}")
        return False

    # Keep the untouched bytes so the backup really is the *original* file.
    original_bytes = config_file.read_bytes()

    # An empty YAML document loads as None; treat it as an empty mapping.
    config = yaml.safe_load(original_bytes.decode('utf-8')) or {}

    # Derive the repo name from the last URL segment. Only a trailing ".git"
    # is stripped, so dotted names such as "vue.js" are kept intact.
    repo_name = Path(repo_url).name
    if repo_name.endswith('.git') and len(repo_name) > len('.git'):
        repo_name = repo_name[:-len('.git')]

    if local_path is None:
        local_path = f"./repos/{repo_name}"

    # Write the backup before anything is modified; dumping the in-memory
    # dict after the update would "back up" the new settings instead.
    backup_file = config_file.parent / (config_file.stem + '.backup.yaml')
    backup_file.write_bytes(original_bytes)
    print(f"✓ 原配置已备份至: {backup_file}")

    # Create the section when it is missing or null rather than crashing.
    if config.get('repository') is None:
        config['repository'] = {}
    config['repository']['url'] = repo_url
    config['repository']['local_path'] = local_path

    # Persist the updated configuration.
    with open(config_file, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)

    print("✓ 配置已更新:")
    print(f" 仓库名称: {repo_name}")
    print(f" 仓库URL: {repo_url}")
    print(f" 本地路径: {local_path}")
    return True
def main():
    """Collect the repo URL and local path, then update the config.

    The URL is taken from ``sys.argv[1]`` (with an optional local path in
    ``sys.argv[2]``); when no arguments are given, both are read
    interactively. Blank values are never forwarded: a blank URL cancels the
    run, and a blank local path is passed on as ``None`` so that
    ``update_repository_config`` picks its default.
    """
    print("="*70)
    print("代码仓库配置更新工具")
    print("="*70)
    print()

    cli_args = sys.argv[1:]

    if cli_args:
        repo_url = cli_args[0].strip()
    else:
        print("请输入新的代码仓库URL:")
        print("示例: https://github.com/gamosoft/NoteDiscovery.git")
        print()
        repo_url = input("仓库URL: ").strip()

    # Reject a blank URL from either source before touching the config.
    if not repo_url:
        print("❌ 未输入URL,已取消")
        return

    if cli_args:
        local_path = cli_args[1].strip() if len(cli_args) > 1 else ""
    else:
        print()
        use_default_path = input("使用默认本地路径? (y/n, 默认y): ").strip().lower()
        local_path = input("本地路径: ").strip() if use_default_path == 'n' else ""

    # A blank path must not be written to the config as ""; use the default.
    local_path = local_path or None

    print()
    success = update_repository_config(repo_url, local_path)

    if success:
        print()
        print("="*70)
        print("✅ 配置更新成功!")
        print()
        print("下一步:")
        print(" 1. 运行知识检测: python test_base_model_knowledge.py")
        print(" 2. 如果检测通过,开始训练:")
        print(" python 1_repository_analyzer.py")
        print(" python 2_data_generator.py")
        print(" deepspeed --num_gpus=2 3_model_finetuner_v4_OOM_FIX.py")
        print("="*70)
# Run the updater only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|