Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,6 @@ import torch
|
|
| 5 |
import pandas as pd
|
| 6 |
import faiss
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
-
from transformers import CLIPProcessor, CLIPModel
|
| 9 |
import time
|
| 10 |
|
| 11 |
# 创建安全缓存目录
|
|
@@ -16,58 +15,62 @@ os.makedirs(CACHE_DIR, exist_ok=True)
|
|
| 16 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
|
| 17 |
torch.set_num_threads(1)
|
| 18 |
|
| 19 |
-
# 全局变量
|
| 20 |
index = None
|
| 21 |
metadata = None
|
| 22 |
-
clip_model = None
|
| 23 |
-
clip_processor = None
|
| 24 |
|
| 25 |
def load_resources():
|
| 26 |
"""加载所有必要资源(768维专用)"""
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
clip_processor = CLIPProcessor.from_pretrained(
|
| 38 |
-
"openai/clip-vit-large-patch14",
|
| 39 |
-
cache_dir=CACHE_DIR,
|
| 40 |
-
token=None
|
| 41 |
-
)
|
| 42 |
-
print("CLIP模型加载完成")
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
if index
|
| 46 |
-
print("
|
| 47 |
-
INDEX_PATH = hf_hub_download(
|
| 48 |
-
repo_id="GOGO198/GOGO_rag_index",
|
| 49 |
-
filename="faiss_index.bin",
|
| 50 |
-
cache_dir=CACHE_DIR,
|
| 51 |
-
token=os.getenv("HF_TOKEN")
|
| 52 |
-
)
|
| 53 |
-
index = faiss.read_index(INDEX_PATH)
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
if index
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def predict(vector):
|
| 73 |
"""处理768维向量输入并返回答案"""
|
|
@@ -133,8 +136,13 @@ with gr.Blocks() as demo:
|
|
| 133 |
# 启动应用
|
| 134 |
if __name__ == "__main__":
|
| 135 |
# 预加载资源
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# 确保缓存目录存在
|
| 140 |
# import pathlib
|
|
|
|
| 5 |
import pandas as pd
|
| 6 |
import faiss
|
| 7 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 8 |
import time
|
| 9 |
|
| 10 |
# 创建安全缓存目录
|
|
|
|
| 15 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
|
| 16 |
torch.set_num_threads(1)
|
| 17 |
|
| 18 |
+
# 全局变量 - 移除了clip_model和clip_processor
|
| 19 |
index = None
|
| 20 |
metadata = None
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def load_resources():
    """Load the FAISS index and the metadata table into module globals.

    Leftover ``*.lock`` files located directly inside ``CACHE_DIR`` are
    removed first (best effort). Afterwards ``faiss_index.bin`` and
    ``metadata.csv`` are downloaded from the ``GOGO198/GOGO_rag_index`` repo
    and stored in the module-level ``index`` and ``metadata``. A resource
    that is already loaded is not fetched again, so the function can be
    called repeatedly.

    Raises:
        ValueError: If the downloaded index is not 768-dimensional.
        Exception: Any download/read failure is printed and re-raised.
    """
    global index, metadata

    # Clean up leftover lock files. This is deliberately best effort: a file
    # that cannot be removed is reported and skipped instead of aborting.
    for lock_file in os.listdir(CACHE_DIR):
        if not lock_file.endswith('.lock'):
            continue
        try:
            os.remove(os.path.join(CACHE_DIR, lock_file))
            print(f"🧹 清理锁文件: {lock_file}")
        except OSError as e:
            print(f"⚠️ 锁文件清理失败: {lock_file} ({e})")

    # Explicit ``None`` checks are required here: a pandas DataFrame has no
    # unambiguous truth value, so ``all([index, metadata])`` raises
    # ValueError as soon as ``metadata`` has been loaded.
    if index is not None and metadata is not None:
        return

    print("🔄 正在加载所有资源...")

    # Load the FAISS index (768-dimensional).
    if index is None:
        print("📥 正在下载FAISS索引...")
        try:
            index_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="faiss_index.bin",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            loaded_index = faiss.read_index(index_path)

            # Validate before publishing to the global so a wrong-dimension
            # index is never left installed after the error is raised.
            if loaded_index.d != 768:
                raise ValueError(
                    f"❌ 索引维度错误:预期768维,实际{loaded_index.d}维"
                )
            index = loaded_index
            print(f"✅ FAISS索引加载完成 | 维度: {index.d}")
        except Exception as e:
            print(f"❌ FAISS索引加载失败: {str(e)}")
            raise

    # Load the metadata table.
    if metadata is None:
        print("📄 正在下载元数据...")
        try:
            metadata_path = hf_hub_download(
                repo_id="GOGO198/GOGO_rag_index",
                filename="metadata.csv",
                cache_dir=CACHE_DIR,
                token=os.getenv("HF_TOKEN"),
            )
            metadata = pd.read_csv(metadata_path)
            print(f"✅ 元数据加载完成 | 记录数: {len(metadata)}")
        except Exception as e:
            print(f"❌ 元数据加载失败: {str(e)}")
            raise
|
| 74 |
|
| 75 |
def predict(vector):
|
| 76 |
"""处理768维向量输入并返回答案"""
|
|
|
|
| 136 |
# 启动应用
|
| 137 |
if __name__ == "__main__":
|
| 138 |
# 预加载资源
|
| 139 |
+
if index is None or metadata is None:
|
| 140 |
+
print("🚀 启动前预加载资源...")
|
| 141 |
+
try:
|
| 142 |
+
load_resources()
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"⛔ 资源加载失败: {str(e)}")
|
| 145 |
+
|
| 146 |
|
| 147 |
# 确保缓存目录存在
|
| 148 |
# import pathlib
|