|
|
import hashlib
import os
import re
from typing import Dict, List, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient
|
|
|
|
|
class MemoryManager:
    """Vector memory manager: stores and retrieves text chunks related to one character.

    Each character gets its own ChromaDB collection persisted under
    ``Config.VECTOR_DB_PATH``. Chunks are stored as raw documents and searched
    with ``query_texts``, so embedding is left to the collection's configured
    embedding function. ``get_embedding`` is a separate OpenAI-based helper
    that the add/search paths in this class do not call.
    """

    # Number of documents written to Chroma per call.
    _BATCH_SIZE = 100

    # Chroma's classic collection-name rule: 3-63 characters from
    # [a-zA-Z0-9._-], starting and ending with an alphanumeric character.
    _VALID_COLLECTION_NAME = re.compile(r"[a-zA-Z0-9][a-zA-Z0-9._-]{1,61}[a-zA-Z0-9]")

    def __init__(self, character_name: str):
        """Open (or create) the persistent vector collection for a character.

        Args:
            character_name: Name of the character whose text chunks are stored.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)
        self.chroma_client = self._create_chroma_client(Config.VECTOR_DB_PATH)

        metadata = {"character": character_name}
        try:
            self.collection = self.chroma_client.get_or_create_collection(
                name=self._build_collection_name(character_name),
                metadata=metadata,
            )
        except Exception as e:
            print(f"创建集合时出错: {e}")
            # Retry with a name derived only from a stable digest of the
            # character name, so the same collection is found on every run.
            self.collection = self.chroma_client.get_or_create_collection(
                name=self._hashed_collection_name(character_name),
                metadata=metadata,
            )

    @staticmethod
    def _create_chroma_client(path: str):
        """Return a ChromaDB client that persists its data under ``path``."""
        # On chromadb >= 0.4, chromadb.Client(Settings(persist_directory=...))
        # does not raise; it returns an in-memory client and ignores the
        # directory. Prefer PersistentClient whenever the library provides it.
        if hasattr(chromadb, "PersistentClient"):
            return chromadb.PersistentClient(
                path=path,
                settings=Settings(anonymized_telemetry=False),
            )
        # Legacy (< 0.4) API.
        return chromadb.Client(Settings(
            persist_directory=path,
            anonymized_telemetry=False,
        ))

    @classmethod
    def _build_collection_name(cls, character_name: str) -> str:
        """Return a deterministic, Chroma-valid collection name for a character.

        The historical scheme ("char_" + lower-cased name with spaces replaced
        by underscores, cut to 63 characters) is kept whenever it yields a
        valid name, so collections created by earlier versions are reused.
        Otherwise (e.g. non-ASCII names, a trailing separator) a digest-based
        name is returned instead.
        """
        legacy = f"char_{character_name.replace(' ', '_').lower()}"[:63]
        if cls._VALID_COLLECTION_NAME.fullmatch(legacy) and ".." not in legacy:
            return legacy
        return cls._hashed_collection_name(character_name)

    @staticmethod
    def _hashed_collection_name(character_name: str) -> str:
        """Return "char_" followed by 16 hex digits of the name's SHA-1 digest."""
        # Use hashlib rather than the built-in hash(): str hashes are salted per
        # process (PYTHONHASHSEED), so they would give a new name on every run.
        digest = hashlib.sha1(character_name.encode("utf-8")).hexdigest()[:16]
        return f"char_{digest}"

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Args:
            chunks: All text chunks. Each entry is read for its ``'text'`` and
                ``'start'`` keys.
            character_chunks: Indices into ``chunks`` where the character
                appears. Out-of-range (including negative) indices are skipped,
                and repeated indices are stored once.

        Write failures are printed and swallowed, so the caller keeps running
        without memory.
        """
        documents = []
        metadatas = []
        ids = []

        # dict.fromkeys removes duplicates while keeping order; duplicate ids
        # within a single write are rejected by Chroma.
        for chunk_id in dict.fromkeys(character_chunks):
            # Explicit lower bound: a negative index would silently pick a
            # chunk from the end of the list.
            if not 0 <= chunk_id < len(chunks):
                continue
            chunk = chunks[chunk_id]
            documents.append(chunk['text'])
            metadatas.append({
                'chunk_id': chunk_id,
                'position': chunk['start'],
            })
            ids.append(f"chunk_{chunk_id}")

        if not documents:
            return

        # add() does not overwrite ids that already exist in a persisted
        # collection (depending on the chromadb version it raises or skips
        # them). upsert() makes re-indexing the same chunks idempotent; fall
        # back to add() on versions without it.
        write = getattr(self.collection, "upsert", None) or self.collection.add

        try:
            for start in range(0, len(documents), self._BATCH_SIZE):
                end = start + self._BATCH_SIZE
                write(
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                    ids=ids[start:end],
                )
            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve stored text chunks relevant to a query.

        Args:
            query: Query text.
            n_results: Maximum number of chunks to return. ``None`` means
                ``Config.MAX_MEMORY_RETRIEVAL``; values <= 0 return ``[]``.

        Returns:
            The matching document strings. Returns ``[]`` if the collection is
            empty or the query fails; failures are printed.
        """
        if n_results is None:
            n_results = Config.MAX_MEMORY_RETRIEVAL
        if n_results <= 0:
            return []

        try:
            collection_count = self.collection.count()
            if collection_count == 0:
                return []

            # Never ask for more results than the collection holds.
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, collection_count),
            )

            if results and results['documents']:
                # One query text was sent, so the first result list is ours.
                return results['documents'][0]
            return []

        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text`` from the OpenAI client.

        Args:
            text: Input text.

        Returns:
            The embedding computed with ``Config.EMBEDDING_MODEL``, or ``[]``
            if the request fails (the error is printed).
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return basic statistics about this character's memory store.

        Returns:
            A dict with ``character``, ``chunk_count`` and ``collection_name``.
            If the collection cannot be read, the count is 0 and the name is
            ``'unknown'``.
        """
        try:
            return {
                'character': self.character_name,
                'chunk_count': self.collection.count(),
                'collection_name': self.collection.name,
            }
        except Exception:
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown',
            }

    def clear(self):
        """Delete every stored chunk for this character.

        The collection is dropped and then recreated empty under the same name,
        so the manager stays usable for later adds and searches. Failures are
        printed, not raised.
        """
        try:
            name = self.collection.name
            self.chroma_client.delete_collection(name)
            # The old handle now refers to a deleted collection; recreate it so
            # add_text_chunks / search_relevant_context keep working.
            self.collection = self.chroma_client.get_or_create_collection(
                name=name,
                metadata={"character": self.character_name},
            )
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")