GOGO198 committed on
Commit
5a49e7f
·
verified ·
1 Parent(s): b9911e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -48
app.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  import pandas as pd
6
  import faiss
7
  from huggingface_hub import hf_hub_download
8
- from transformers import CLIPProcessor, CLIPModel
9
  import time
10
 
11
  # 创建安全缓存目录
@@ -16,58 +15,62 @@ os.makedirs(CACHE_DIR, exist_ok=True)
16
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
17
  torch.set_num_threads(1)
18
 
19
- # 全局变量
20
  index = None
21
  metadata = None
22
- clip_model = None
23
- clip_processor = None
24
 
25
  def load_resources():
26
  """加载所有必要资源(768维专用)"""
27
- global index, metadata, clip_model, clip_processor
28
-
29
- # 加载CLIP模型(用于维度验证)
30
- if clip_model is None:
31
- print("正在加载CLIP模型...")
32
- clip_model = CLIPModel.from_pretrained(
33
- "openai/clip-vit-large-patch14",
34
- cache_dir=CACHE_DIR,
35
- token=None
36
- )
37
- clip_processor = CLIPProcessor.from_pretrained(
38
- "openai/clip-vit-large-patch14",
39
- cache_dir=CACHE_DIR,
40
- token=None
41
- )
42
- print("CLIP模型加载完成")
43
 
44
- # 加载FAISS索引(768维)
45
- if index is None:
46
- print("正在下载FAISS索引...")
47
- INDEX_PATH = hf_hub_download(
48
- repo_id="GOGO198/GOGO_rag_index",
49
- filename="faiss_index.bin",
50
- cache_dir=CACHE_DIR,
51
- token=os.getenv("HF_TOKEN")
52
- )
53
- index = faiss.read_index(INDEX_PATH)
54
 
55
- # 验证索引维度
56
- if index.d != 768:
57
- raise ValueError(f"索引维度错误:预期768维,实际{index.d}维")
58
- print("FAISS索引加载完成 | 维度: 768")
59
-
60
- # 加载元数据
61
- if metadata is None:
62
- print("正在下载元数据...")
63
- METADATA_PATH = hf_hub_download(
64
- repo_id="GOGO198/GOGO_rag_index",
65
- filename="metadata.csv",
66
- cache_dir=CACHE_DIR,
67
- token=os.getenv("HF_TOKEN")
68
- )
69
- metadata = pd.read_csv(METADATA_PATH)
70
- print("元数据加载完成")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  def predict(vector):
73
  """处理768维向量输入并返回答案"""
@@ -133,8 +136,13 @@ with gr.Blocks() as demo:
133
  # 启动应用
134
  if __name__ == "__main__":
135
  # 预加载资源
136
- print("启动前预加载资源...")
137
- load_resources()
 
 
 
 
 
138
 
139
  # 确保缓存目录存在
140
  # import pathlib
 
5
  import pandas as pd
6
  import faiss
7
  from huggingface_hub import hf_hub_download
 
8
  import time
9
 
10
  # 创建安全缓存目录
 
15
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
16
  torch.set_num_threads(1)
17
 
18
+ # 全局变量 - 移除了clip_model和clip_processor
19
  index = None
20
  metadata = None
 
 
21
 
22
  def load_resources():
23
  """加载所有必要资源(768维专用)"""
24
+ # 清理残留锁文件
25
+ lock_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.lock')]
26
+ for lock_file in lock_files:
27
+ try:
28
+ os.remove(os.path.join(CACHE_DIR, lock_file))
29
+ print(f"🧹 清理锁文件: {lock_file}")
30
+ except:
31
+ pass
32
+
33
+ global index, metadata
 
 
 
 
 
 
34
 
35
+ # 仅当资源未加载时才初始化
36
+ if not all([index, metadata]):
37
+ print("🔄 正在加载所有资源...")
 
 
 
 
 
 
 
38
 
39
+ # 加载FAISS索引(768维)
40
+ if index is None:
41
+ print("📥 正在下载FAISS索引...")
42
+ try:
43
+ INDEX_PATH = hf_hub_download(
44
+ repo_id="GOGO198/GOGO_rag_index",
45
+ filename="faiss_index.bin",
46
+ cache_dir=CACHE_DIR,
47
+ token=os.getenv("HF_TOKEN")
48
+ )
49
+ index = faiss.read_index(INDEX_PATH)
50
+
51
+ # 验证索引维度
52
+ if index.d != 768:
53
+ raise ValueError(f"❌ 索引维度错误:预期768维,实际{index.d}维")
54
+ print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
55
+ except Exception as e:
56
+ print(f"❌ FAISS索引加载失败: {str(e)}")
57
+ raise
58
+
59
+ # 加载元数据
60
+ if metadata is None:
61
+ print("📄 正在下载元数据...")
62
+ try:
63
+ METADATA_PATH = hf_hub_download(
64
+ repo_id="GOGO198/GOGO_rag_index",
65
+ filename="metadata.csv",
66
+ cache_dir=CACHE_DIR,
67
+ token=os.getenv("HF_TOKEN")
68
+ )
69
+ metadata = pd.read_csv(METADATA_PATH)
70
+ print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
71
+ except Exception as e:
72
+ print(f"❌ 元数据加载失败: {str(e)}")
73
+ raise
74
 
75
  def predict(vector):
76
  """处理768维向量输入并返回答案"""
 
136
  # 启动应用
137
  if __name__ == "__main__":
138
  # 预加载资源
139
+ if index is None or metadata is None:
140
+ print("🚀 启动前预加载资源...")
141
+ try:
142
+ load_resources()
143
+ except Exception as e:
144
+ print(f"⛔ 资源加载失败: {str(e)}")
145
+
146
 
147
  # 确保缓存目录存在
148
  # import pathlib