# GraphRAG_Backend/database_setup_lite.py
"""
轻量级图向量数据库 - 不依赖 ChromaDB,避免超过 Vercel 250MB 限制
使用纯 Python 实现简单的向量搜索
"""
import os
import json
import math
from typing import List, Dict, Optional
import requests
# Lazily import sentence_transformers to avoid dependency conflicts
HAS_SENTENCE_TRANSFORMERS = False
SentenceTransformer = None
def _try_import_sentence_transformers():
"""尝试导入 sentence_transformers"""
global HAS_SENTENCE_TRANSFORMERS, SentenceTransformer
if HAS_SENTENCE_TRANSFORMERS:
return True
try:
from sentence_transformers import SentenceTransformer as ST
SentenceTransformer = ST
HAS_SENTENCE_TRANSFORMERS = True
return True
    except (ImportError, RuntimeError, AttributeError):
HAS_SENTENCE_TRANSFORMERS = False
return False
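# Illustrative only (not part of the original module): a hypothetical helper
# showing how the lazy-import guard is meant to be used - load a local
# sentence_transformers model only when the import actually succeeds.
def _load_local_model(model_name: str = "all-MiniLM-L6-v2"):
    """Return a SentenceTransformer for model_name, or None if unavailable."""
    if not _try_import_sentence_transformers():
        return None
    return SentenceTransformer(model_name)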
# Simple in-memory graph database
class SimpleGraphDB:
"""简单的内存图数据库模拟"""
def __init__(self):
self.nodes = {} # {node_id: {type, properties}}
self.edges = [] # [{source, target, relationship}]
def add_node(self, node_id: str, node_type: str, properties: Dict):
"""添加节点"""
self.nodes[node_id] = {
"type": node_type,
"properties": properties
}
def add_edge(self, source: str, target: str, relationship: str):
"""添加边"""
self.edges.append({
"source": source,
"target": target,
"relationship": relationship
})
def get_neighbors(self, node_id: str, relationship: Optional[str] = None) -> List[Dict]:
"""获取邻居节点"""
neighbors = []
for edge in self.edges:
if edge["source"] == node_id:
if relationship is None or edge["relationship"] == relationship:
target_node = self.nodes.get(edge["target"], {})
neighbors.append({
"node_id": edge["target"],
"relationship": edge["relationship"],
"properties": target_node.get("properties", {})
})
return neighbors
def find_nodes_by_type(self, node_type: str) -> List[Dict]:
"""根据类型查找节点"""
return [
{"id": node_id, **node_data}
for node_id, node_data in self.nodes.items()
if node_data["type"] == node_type
]
def find_node_by_property(self, node_type: str, property_name: str, property_value: str) -> Optional[Dict]:
"""根据属性查找节点"""
for node_id, node_data in self.nodes.items():
if node_data["type"] == node_type:
props = node_data.get("properties", {})
if props.get(property_name) == property_value:
return {"id": node_id, **node_data}
return None
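# Illustrative only: a tiny usage sketch for SimpleGraphDB. The node ids and
# the relationship name below are invented for the example.
def _example_graph_usage() -> List[Dict]:
    db = SimpleGraphDB()
    db.add_node("p1", "Product", {"name": "Sneaker"})
    db.add_node("f1", "Feature", {"name": "lightweight"})
    db.add_edge("p1", "f1", "HAS_FEATURE")
    # Returns the one outgoing neighbor of "p1", with its properties attached
    return db.get_neighbors("p1", relationship="HAS_FEATURE")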
def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""计算余弦相似度"""
dot_product = sum(a * b for a, b in zip(vec1, vec2))
magnitude1 = math.sqrt(sum(a * a for a in vec1))
magnitude2 = math.sqrt(sum(a * a for a in vec2))
if magnitude1 == 0 or magnitude2 == 0:
return 0.0
return dot_product / (magnitude1 * magnitude2)
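# Worked example: for vec1 = [1, 0] and vec2 = [1, 1],
# dot = 1*1 + 0*1 = 1, |vec1| = 1, |vec2| = sqrt(2),
# so cosine_similarity(vec1, vec2) = 1 / sqrt(2) ≈ 0.7071.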
class VectorDB:
"""轻量级向量数据库 - 使用内存存储,不依赖 ChromaDB"""
def __init__(self):
        # Document store: {id: {content, metadata, embedding}}
self.documents: Dict[str, Dict] = {}
        # Embedding configuration
self.embedding_api_base = os.getenv("LLM_API_BASE", "https://api.ai-gaochao.cn/v1")
self.embedding_api_key = os.getenv("LLM_API_KEY", "")
self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
self.use_openai_embedding = bool(self.embedding_api_key)
if self.use_openai_embedding:
print(f"✅ 使用 OpenAI Embeddings API: {self.embedding_model}")
else:
print("ℹ️ 使用简单文本匹配(关键词搜索)")
def _get_openai_embeddings(self, texts: List[str]) -> List[List[float]]:
"""调用 OpenAI Embeddings API 获取向量"""
url = f"{self.embedding_api_base}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.embedding_api_key}"
}
data = {
"input": texts,
"model": self.embedding_model
}
response = requests.post(url, headers=headers, json=data, timeout=30)
response.raise_for_status()
result = response.json()
return [item["embedding"] for item in result["data"]]
def _simple_text_match(self, query: str, document: str) -> float:
"""简单的文本匹配评分(关键词匹配)"""
query_words = set(query.lower().split())
doc_words = set(document.lower().split())
if not query_words:
return 0.0
        # Fraction of query keywords that appear in the document
matches = len(query_words & doc_words)
return matches / len(query_words)
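    # Worked example: the query "red running shoes" against the document
    # "lightweight red shoes for runners" shares {"red", "shoes"}, i.e. 2 of
    # the 3 query words, so the score is 2/3 ≈ 0.67.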
def add_documents(self, documents: List[str], ids: List[str], metadatas: List[Dict]):
"""添加文档到向量数据库"""
if self.use_openai_embedding:
            # Embed the documents via the OpenAI Embeddings API
try:
embeddings = self._get_openai_embeddings(documents)
for doc, doc_id, meta, emb in zip(documents, ids, metadatas, embeddings):
self.documents[doc_id] = {
"content": doc,
"metadata": meta,
"embedding": emb
}
except Exception as e:
print(f"⚠️ OpenAI Embeddings API 调用失败: {e}")
# 回退到简单存储(无 embedding)
for doc, doc_id, meta in zip(documents, ids, metadatas):
self.documents[doc_id] = {
"content": doc,
"metadata": meta,
"embedding": None
}
else:
            # No embeddings configured; store the documents as-is
for doc, doc_id, meta in zip(documents, ids, metadatas):
self.documents[doc_id] = {
"content": doc,
"metadata": meta,
"embedding": None
}
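    # Usage sketch (the id and metadata values are invented for the example):
    #   vdb = VectorDB()
    #   vdb.add_documents(["Fresh picks for spring"], ["c1"], [{"tag": "spring"}])
    #   vdb.search("spring", n_results=1)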
def search(self, query: str, n_results: int = 5) -> List[Dict]:
"""语义搜索"""
if self.use_openai_embedding:
            # Vector-similarity search
try:
query_embedding = self._get_openai_embeddings([query])[0]
                # Score every stored document against the query
results = []
for doc_id, doc_data in self.documents.items():
if doc_data["embedding"]:
similarity = cosine_similarity(query_embedding, doc_data["embedding"])
results.append({
"content": doc_data["content"],
"metadata": doc_data["metadata"],
"distance": 1 - similarity, # 转换为距离(越小越相似)
"id": doc_id
})
                # Sort by ascending distance (most similar first)
results.sort(key=lambda x: x["distance"])
return results[:n_results]
except Exception as e:
print(f"⚠️ 向量搜索失败,回退到文本匹配: {e}")
# 回退到文本匹配
return self._text_search(query, n_results)
else:
            # Simple keyword matching
return self._text_search(query, n_results)
def _text_search(self, query: str, n_results: int) -> List[Dict]:
"""简单的文本匹配搜索"""
results = []
for doc_id, doc_data in self.documents.items():
score = self._simple_text_match(query, doc_data["content"])
if score > 0:
results.append({
"content": doc_data["content"],
"metadata": doc_data["metadata"],
"distance": 1 - score, # 转换为距离
"id": doc_id
})
        # Sort by ascending distance
results.sort(key=lambda x: x["distance"])
return results[:n_results]
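    # Both search paths return the same result shape: a list of dicts with
    # "content", "metadata", "id", and a "distance" where smaller means more
    # similar, so callers can rank results the same way on either backend.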
@property
def collection(self):
"""兼容性属性,模拟 ChromaDB 的 collection 接口"""
class MockCollection:
def __init__(self, vector_db):
self.vector_db = vector_db
def get(self):
"""获取所有文档"""
ids = list(self.vector_db.documents.keys())
documents = [self.vector_db.documents[id]["content"] for id in ids]
metadatas = [self.vector_db.documents[id]["metadata"] for id in ids]
return {
"ids": ids,
"documents": documents,
"metadatas": metadatas
}
return MockCollection(self)
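# Compatibility sketch: code written against ChromaDB can keep calling
#     vector_db.collection.get()
# and receive a dict with "ids", "documents", and "metadatas" - the subset of
# ChromaDB's collection.get() result that this module emulates.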
def setup_databases():
"""初始化数据库"""
# 加载数据
with open("mock_data.json", "r", encoding="utf-8") as f:
data = json.load(f)
    # Initialize the graph database
graph_db = SimpleGraphDB()
    # Add product nodes
for product in data["products"]:
graph_db.add_node(
product["id"],
"Product",
{
"name": product["name"],
"type": product["type"],
"keywords": product["keywords"],
"features": product["features"]
}
)
    # Add style nodes
for style in data["styles"]:
graph_db.add_node(
style["id"],
"Style",
{
"name": style["name"],
"description": style["description"],
"characteristics": style["characteristics"]
}
)
    # Add copywriting nodes
for copy in data["copywritings"]:
graph_db.add_node(
copy["id"],
"Copywriting",
{
"content": copy["content"],
"tag": copy["tag"],
"target_audience": copy["target_audience"]
}
)
    # Add feature nodes
all_features = set()
for product in data["products"]:
for feature in product.get("features", []):
all_features.add(feature)
for feature in all_features:
graph_db.add_node(feature, "Feature", {"name": feature})
    # Add relationship edges
for rel in data["relationships"]:
graph_db.add_edge(
rel["source"],
rel["target"],
rel["relationship"]
)
    # Initialize the lightweight vector database
vector_db = VectorDB()
    # Add the copywriting documents to the vector database
documents = []
ids = []
metadatas = []
for copy in data["copywritings"]:
documents.append(copy["content"])
ids.append(copy["id"])
metadatas.append({
"product_id": copy["product_id"],
"style_id": copy["style_id"],
"tag": copy["tag"],
"target_audience": copy["target_audience"]
})
vector_db.add_documents(documents, ids, metadatas)
print(f"✅ 向量数据库已更新,包含 {len(documents)} 个文案")
print("数据库初始化完成!")
print(f"- 图数据库节点数: {len(graph_db.nodes)}")
print(f"- 图数据库边数: {len(graph_db.edges)}")
print(f"- 向量数据库文档数: {len(documents)}")
return graph_db, vector_db
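# A minimal sketch of running the setup directly, assuming mock_data.json is
# present in the working directory (setup_databases() reads it from there):
if __name__ == "__main__":
    setup_databases()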