| """ | |
| 轻量级图向量数据库 - 不依赖 ChromaDB,避免超过 Vercel 250MB 限制 | |
| 使用纯 Python 实现简单的向量搜索 | |
| """ | |
| import os | |
| import json | |
| import math | |
| from typing import List, Dict, Optional | |
| import requests | |
| # 延迟导入 sentence_transformers,避免依赖冲突 | |
| HAS_SENTENCE_TRANSFORMERS = False | |
| SentenceTransformer = None | |

def _try_import_sentence_transformers():
    """Try to import sentence_transformers; return True on success."""
    global HAS_SENTENCE_TRANSFORMERS, SentenceTransformer
    if HAS_SENTENCE_TRANSFORMERS:
        return True
    try:
        from sentence_transformers import SentenceTransformer as ST
        SentenceTransformer = ST
        HAS_SENTENCE_TRANSFORMERS = True
        return True
    except (ImportError, RuntimeError, AttributeError):
        HAS_SENTENCE_TRANSFORMERS = False
        return False

# Simple in-memory graph database
class SimpleGraphDB:
    """Simple in-memory mock of a graph database."""

    def __init__(self):
        self.nodes = {}  # {node_id: {"type": ..., "properties": ...}}
        self.edges = []  # [{"source": ..., "target": ..., "relationship": ...}]

    def add_node(self, node_id: str, node_type: str, properties: Dict):
        """Add a node."""
        self.nodes[node_id] = {
            "type": node_type,
            "properties": properties
        }

    def add_edge(self, source: str, target: str, relationship: str):
        """Add a directed edge."""
        self.edges.append({
            "source": source,
            "target": target,
            "relationship": relationship
        })

    def get_neighbors(self, node_id: str, relationship: Optional[str] = None) -> List[Dict]:
        """Return outgoing neighbors, optionally filtered by relationship type."""
        neighbors = []
        for edge in self.edges:
            if edge["source"] == node_id:
                if relationship is None or edge["relationship"] == relationship:
                    target_node = self.nodes.get(edge["target"], {})
                    neighbors.append({
                        "node_id": edge["target"],
                        "relationship": edge["relationship"],
                        "properties": target_node.get("properties", {})
                    })
        return neighbors

    def find_nodes_by_type(self, node_type: str) -> List[Dict]:
        """Find all nodes of a given type."""
        return [
            {"id": node_id, **node_data}
            for node_id, node_data in self.nodes.items()
            if node_data["type"] == node_type
        ]

    def find_node_by_property(self, node_type: str, property_name: str, property_value: str) -> Optional[Dict]:
        """Find the first node of a given type whose property matches a value."""
        for node_id, node_data in self.nodes.items():
            if node_data["type"] == node_type:
                props = node_data.get("properties", {})
                if props.get(property_name) == property_value:
                    return {"id": node_id, **node_data}
        return None
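
# Illustrative usage sketch (node ids and labels below are made up for the
# example; nothing in this module calls this helper):
def _demo_simple_graph_db():
    db = SimpleGraphDB()
    db.add_node("p1", "Product", {"name": "lipstick"})
    db.add_node("s1", "Style", {"name": "playful"})
    db.add_edge("p1", "s1", "HAS_STYLE")
    # -> [{"node_id": "s1", "relationship": "HAS_STYLE", "properties": {"name": "playful"}}]
    return db.get_neighbors("p1", relationship="HAS_STYLE")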

def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Compute the cosine similarity between two vectors."""
    dot_product = sum(a * b for a, b in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(a * a for a in vec1))
    magnitude2 = math.sqrt(sum(a * a for a in vec2))
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0
    return dot_product / (magnitude1 * magnitude2)
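
# Worked example: cosine_similarity([1.0, 0.0], [1.0, 0.0]) == 1.0 (same
# direction), cosine_similarity([1.0, 0.0], [0.0, 1.0]) == 0.0 (orthogonal),
# and a zero vector returns 0.0 instead of raising a division error.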

class VectorDB:
    """Lightweight in-memory vector database; no ChromaDB dependency."""

    def __init__(self):
        # Document store: {id: {"content": ..., "metadata": ..., "embedding": ...}}
        self.documents: Dict[str, Dict] = {}
        # Embedding configuration
        self.embedding_api_base = os.getenv("LLM_API_BASE", "https://api.ai-gaochao.cn/v1")
        self.embedding_api_key = os.getenv("LLM_API_KEY", "")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        self.use_openai_embedding = bool(self.embedding_api_key)
        if self.use_openai_embedding:
            print(f"✅ Using OpenAI Embeddings API: {self.embedding_model}")
        else:
            print("ℹ️ Using simple text matching (keyword search)")

    def _get_openai_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Fetch embedding vectors from the OpenAI Embeddings API."""
        url = f"{self.embedding_api_base}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.embedding_api_key}"
        }
        data = {
            "input": texts,
            "model": self.embedding_model
        }
        response = requests.post(url, headers=headers, json=data, timeout=30)
        response.raise_for_status()
        result = response.json()
        return [item["embedding"] for item in result["data"]]
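
    # Note: the response follows the OpenAI embeddings schema, roughly
    #   {"data": [{"index": 0, "embedding": [0.0123, ...]}, ...], "model": "..."}
    # so the comprehension above yields one vector per input text, in input order.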

    def _simple_text_match(self, query: str, document: str) -> float:
        """Score a document by keyword overlap with the query."""
        query_words = set(query.lower().split())
        doc_words = set(document.lower().split())
        if not query_words:
            return 0.0
        # Fraction of query words that appear in the document
        matches = len(query_words & doc_words)
        return matches / len(query_words)
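
    # Example: for the query "soft matte lipstick" and a document containing
    # "matte" and "lipstick" but not "soft", the score is 2/3. This is plain
    # whitespace tokenization, so it will not split unsegmented CJK text.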

    def add_documents(self, documents: List[str], ids: List[str], metadatas: List[Dict]):
        """Add documents to the vector database."""
        if self.use_openai_embedding:
            # Use the OpenAI Embeddings API
            try:
                embeddings = self._get_openai_embeddings(documents)
                for doc, doc_id, meta, emb in zip(documents, ids, metadatas, embeddings):
                    self.documents[doc_id] = {
                        "content": doc,
                        "metadata": meta,
                        "embedding": emb
                    }
            except Exception as e:
                print(f"⚠️ OpenAI Embeddings API call failed: {e}")
                # Fall back to plain storage (no embeddings)
                for doc, doc_id, meta in zip(documents, ids, metadatas):
                    self.documents[doc_id] = {
                        "content": doc,
                        "metadata": meta,
                        "embedding": None
                    }
        else:
            # No embeddings available; store the documents only
            for doc, doc_id, meta in zip(documents, ids, metadatas):
                self.documents[doc_id] = {
                    "content": doc,
                    "metadata": meta,
                    "embedding": None
                }

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Semantic search."""
        if self.use_openai_embedding:
            # Vector similarity search
            try:
                query_embedding = self._get_openai_embeddings([query])[0]
                # Score every stored document against the query
                results = []
                for doc_id, doc_data in self.documents.items():
                    if doc_data["embedding"]:
                        similarity = cosine_similarity(query_embedding, doc_data["embedding"])
                        results.append({
                            "content": doc_data["content"],
                            "metadata": doc_data["metadata"],
                            "distance": 1 - similarity,  # convert to a distance (smaller = more similar)
                            "id": doc_id
                        })
                # Sort by similarity (ascending distance)
                results.sort(key=lambda x: x["distance"])
                return results[:n_results]
            except Exception as e:
                print(f"⚠️ Vector search failed, falling back to text matching: {e}")
                return self._text_search(query, n_results)
        else:
            # Simple text matching
            return self._text_search(query, n_results)

    def _text_search(self, query: str, n_results: int) -> List[Dict]:
        """Simple keyword-matching search."""
        results = []
        for doc_id, doc_data in self.documents.items():
            score = self._simple_text_match(query, doc_data["content"])
            if score > 0:
                results.append({
                    "content": doc_data["content"],
                    "metadata": doc_data["metadata"],
                    "distance": 1 - score,  # convert to a distance
                    "id": doc_id
                })
        # Sort by similarity (ascending distance)
        results.sort(key=lambda x: x["distance"])
        return results[:n_results]

    def collection(self):
        """Compatibility shim mimicking ChromaDB's collection interface."""
        class MockCollection:
            def __init__(self, vector_db):
                self.vector_db = vector_db

            def get(self):
                """Return all documents."""
                ids = list(self.vector_db.documents.keys())
                documents = [self.vector_db.documents[doc_id]["content"] for doc_id in ids]
                metadatas = [self.vector_db.documents[doc_id]["metadata"] for doc_id in ids]
                return {
                    "ids": ids,
                    "documents": documents,
                    "metadatas": metadatas
                }

        return MockCollection(self)
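
# Illustrative usage sketch: with no LLM_API_KEY set, VectorDB falls back to
# keyword matching, so this runs fully offline. Ids and texts are made up.
def _demo_vector_db():
    db = VectorDB()
    db.add_documents(
        ["soft matte lipstick", "hydrating face cream"],
        ["c1", "c2"],
        [{"tag": "makeup"}, {"tag": "skincare"}],
    )
    # "c1" ranks first: all three query words match it, giving distance 0.0
    return db.search("soft matte lipstick", n_results=1)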

def setup_databases():
    """Initialize both databases from mock_data.json."""
    # Load the data
    with open("mock_data.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    # Initialize the graph database
    graph_db = SimpleGraphDB()

    # Add product nodes
    for product in data["products"]:
        graph_db.add_node(
            product["id"],
            "Product",
            {
                "name": product["name"],
                "type": product["type"],
                "keywords": product["keywords"],
                "features": product["features"]
            }
        )

    # Add style nodes
    for style in data["styles"]:
        graph_db.add_node(
            style["id"],
            "Style",
            {
                "name": style["name"],
                "description": style["description"],
                "characteristics": style["characteristics"]
            }
        )

    # Add copywriting nodes
    for copy in data["copywritings"]:
        graph_db.add_node(
            copy["id"],
            "Copywriting",
            {
                "content": copy["content"],
                "tag": copy["tag"],
                "target_audience": copy["target_audience"]
            }
        )

    # Add feature nodes
    all_features = set()
    for product in data["products"]:
        for feature in product.get("features", []):
            all_features.add(feature)
    for feature in all_features:
        graph_db.add_node(feature, "Feature", {"name": feature})

    # Add relationships
    for rel in data["relationships"]:
        graph_db.add_edge(
            rel["source"],
            rel["target"],
            rel["relationship"]
        )

    # Initialize the lightweight vector database
    vector_db = VectorDB()

    # Add the copywriting entries to the vector database
    documents = []
    ids = []
    metadatas = []
    for copy in data["copywritings"]:
        documents.append(copy["content"])
        ids.append(copy["id"])
        metadatas.append({
            "product_id": copy["product_id"],
            "style_id": copy["style_id"],
            "tag": copy["tag"],
            "target_audience": copy["target_audience"]
        })
    vector_db.add_documents(documents, ids, metadatas)
    print(f"✅ Vector database updated with {len(documents)} copywriting entries")

    print("Database initialization complete!")
    print(f"- Graph database nodes: {len(graph_db.nodes)}")
    print(f"- Graph database edges: {len(graph_db.edges)}")
    print(f"- Vector database documents: {len(documents)}")

    return graph_db, vector_db
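
# Minimal smoke test, assuming a mock_data.json with the keys used above
# (products, styles, copywritings, relationships) sits next to this file.
if __name__ == "__main__":
    graph_db, vector_db = setup_databases()
    for hit in vector_db.search("lipstick", n_results=3):
        print(hit["id"], hit["distance"])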