# FictionAgent / core/memory_manager.py
# gdwind's picture
# Upload folder using huggingface_hub
# a226682 verified
import hashlib
import os
from typing import List, Dict, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient
class MemoryManager:
    """Vector memory manager that stores and retrieves text chunks for one character.

    Each character is backed by its own ChromaDB collection. The public
    methods wrap their vector-store calls in ``try``/``except`` and report
    failures with ``print`` instead of raising, so callers keep running
    without the memory feature.
    """

    # Collection names are truncated to this many characters.
    _MAX_COLLECTION_NAME_LEN = 63
    # Number of documents sent to the collection per ``add`` call.
    _ADD_BATCH_SIZE = 100

    def __init__(self, character_name: str):
        """Open (or create) the collection that backs ``character_name``.

        Args:
            character_name: Name of the character. It is used to derive the
                collection name and is stored as collection metadata.

        Raises:
            Exception: Whatever the ChromaDB client raises when both the
                primary and the fallback collection name are rejected.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        # Make sure the vector database directory exists.
        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)
        self.chroma_client = self._create_chroma_client()

        # Every character gets an independent collection.
        try:
            self.collection = self._open_collection(
                self._primary_collection_name(character_name)
            )
        except Exception as e:
            print(f"创建集合时出错: {e}")
            # The readable name was rejected; retry with a simpler,
            # digest-based name.
            self.collection = self._open_collection(
                self._fallback_collection_name(character_name)
            )

    @staticmethod
    def _create_chroma_client():
        """Build the ChromaDB client, preferring the Settings-based constructor."""
        try:
            # NOTE(review): on newer chromadb releases this call appears to
            # succeed but return a non-persistent client unless
            # Settings(is_persistent=True) is given, in which case the
            # PersistentClient fallback below never runs — confirm against
            # the installed chromadb version.
            return chromadb.Client(Settings(
                persist_directory=Config.VECTOR_DB_PATH,
                anonymized_telemetry=False,
            ))
        except Exception:
            # The Settings-based constructor failed; try PersistentClient.
            return chromadb.PersistentClient(path=Config.VECTOR_DB_PATH)

    @classmethod
    def _primary_collection_name(cls, character_name: str) -> str:
        """Return the readable collection name: ``char_<lowercased name>``, truncated."""
        name = f"char_{character_name.replace(' ', '_').lower()}"
        return name[:cls._MAX_COLLECTION_NAME_LEN]

    @staticmethod
    def _fallback_collection_name(character_name: str) -> str:
        """Return an ASCII-only collection name that is stable across runs.

        The built-in ``hash()`` of a string is randomised per process, so it
        would yield a different collection on every start; a SHA-256 digest
        prefix does not.
        """
        digest = hashlib.sha256(character_name.encode("utf-8")).hexdigest()
        return f"char_{digest[:16]}"

    def _open_collection(self, collection_name: str):
        """Get or create ``collection_name``, tagged with the character name."""
        return self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata={"character": self.character_name},
        )

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Ids that fall outside ``chunks`` (including negative ids) and ids
        that repeat are skipped, so every stored document has a unique
        ``chunk_<id>`` identifier that matches its position in ``chunks``.

        Args:
            chunks: All text chunks; each one is read through its ``'text'``
                and ``'start'`` keys.
            character_chunks: Indices into ``chunks`` of the chunks that
                mention the character.
        """
        documents = []
        metadatas = []
        ids = []
        seen = set()
        total = len(chunks)
        for chunk_id in character_chunks:
            # A negative id would silently index from the end of the list,
            # and a repeated id would put a duplicate into the same batch.
            if chunk_id in seen or not 0 <= chunk_id < total:
                continue
            seen.add(chunk_id)
            chunk = chunks[chunk_id]
            documents.append(chunk['text'])
            metadatas.append({
                'chunk_id': chunk_id,
                'position': chunk['start'],
            })
            ids.append(f"chunk_{chunk_id}")

        if not documents:
            return

        try:
            # Add in batches to avoid sending too much at once.
            step = self._ADD_BATCH_SIZE
            for start in range(0, len(documents), step):
                stop = start + step
                self.collection.add(
                    documents=documents[start:stop],
                    metadatas=metadatas[start:stop],
                    ids=ids[start:stop],
                )
            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve the stored text fragments most relevant to ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of fragments to return; a falsy value
                selects ``Config.MAX_MEMORY_RETRIEVAL``.

        Returns:
            The matching text fragments, or an empty list when the collection
            is empty or the lookup fails.
        """
        n_results = n_results or Config.MAX_MEMORY_RETRIEVAL
        try:
            collection_count = self.collection.count()
            if collection_count == 0:
                return []

            # Never request more results than the collection holds.
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, collection_count),
            )
            if results and results['documents']:
                # One query text was sent, so the first entry is its result list.
                return results['documents'][0]
            return []
        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Fetch the embedding vector for ``text``.

        Args:
            text: Input text.

        Returns:
            The embedding vector, or an empty list when the request fails.
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return statistics about the memory store.

        Returns:
            A dict with the keys ``'character'``, ``'chunk_count'`` and
            ``'collection_name'``. When the collection cannot be read, the
            count is 0 and the collection name is ``'unknown'``.
        """
        try:
            return {
                'character': self.character_name,
                'chunk_count': self.collection.count(),
                'collection_name': self.collection.name,
            }
        except Exception:
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown',
            }

    def clear(self):
        """Empty the memory store while keeping this manager usable.

        The backing collection is deleted and then re-created under the same
        name, so later calls to ``add_text_chunks`` or
        ``search_relevant_context`` work against a fresh, empty collection
        instead of a handle to a deleted one.
        """
        try:
            collection_name = self.collection.name
            # Drop the collection, then open an empty one under the same name.
            self.chroma_client.delete_collection(collection_name)
            self.collection = self._open_collection(collection_name)
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")