"""RAG (Retrieval Augmented Generation) services for ScriptVoice."""
import os
import json
import pickle
from typing import List, Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from config import PROJECTS_FILE
class RAGService:
"""Handles vector database operations and content retrieval."""
def __init__(self):
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ". ", "! ", "? ", " "]
)
self.index = None
self.documents = []
self.metadata = []
self.index_file = "vector_index.faiss"
self.metadata_file = "vector_metadata.pkl"
self._load_or_create_index()

    def _load_or_create_index(self):
        """Load existing index or create new one."""
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            try:
                self.index = faiss.read_index(self.index_file)
                with open(self.metadata_file, 'rb') as f:
                    data = pickle.load(f)
                self.documents = data['documents']
                self.metadata = data['metadata']
                print(f"Loaded vector index with {len(self.documents)} documents")
            except Exception as e:
                print(f"Error loading index: {e}")
                self._create_empty_index()
        else:
            self._create_empty_index()

    def _create_empty_index(self):
        """Create empty FAISS index."""
        dimension = 384  # all-MiniLM-L6-v2 embedding dimension
        self.index = faiss.IndexFlatIP(dimension)
        self.documents = []
        self.metadata = []

    def chunk_content(self, content: str, content_type: str, content_id: str, title: str) -> List[Document]:
        """Split content into chunks for embedding."""
        chunks = self.text_splitter.split_text(content)
        documents = []
        for i, chunk in enumerate(chunks):
            doc = Document(
                page_content=chunk,
                metadata={
                    'content_type': content_type,
                    'content_id': content_id,
                    'title': title,
                    'chunk_id': i,
                    'chunk_count': len(chunks)
                }
            )
            documents.append(doc)
        return documents

    def add_content(self, content: str, content_type: str, content_id: str, title: str):
        """Add content to the vector database."""
        if not content.strip():
            return

        # Remove existing content for this ID
        self.remove_content(content_id)

        # Chunk the content
        documents = self.chunk_content(content, content_type, content_id, title)
        if not documents:
            return

        # Generate embeddings
        texts = [doc.page_content for doc in documents]
        embeddings = self.model.encode(texts)

        # Normalize embeddings for cosine similarity
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

        # Add to FAISS index
        self.index.add(embeddings.astype('float32'))

        # Store documents and metadata
        self.documents.extend(documents)
        for doc in documents:
            self.metadata.append(doc.metadata)

        # Save index
        self._save_index()

    def remove_content(self, content_id: str):
        """Remove content from vector database."""
        indices_to_remove = []
        for i, metadata in enumerate(self.metadata):
            if metadata.get('content_id') == content_id:
                indices_to_remove.append(i)

        if indices_to_remove:
            # Rebuild index without removed items
            new_documents = []
            new_metadata = []
            new_embeddings = []
            for i, (doc, meta) in enumerate(zip(self.documents, self.metadata)):
                if i not in indices_to_remove:
                    new_documents.append(doc)
                    new_metadata.append(meta)
                    embedding = self.model.encode([doc.page_content])
                    embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
                    new_embeddings.append(embedding[0])

            # Recreate index
            self._create_empty_index()
            if new_embeddings:
                embeddings_array = np.array(new_embeddings).astype('float32')
                self.index.add(embeddings_array)
            self.documents = new_documents
            self.metadata = new_metadata
            self._save_index()

    def search(self, query: str, k: int = 5, content_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Search for similar content."""
        if self.index.ntotal == 0:
            return []

        # Generate query embedding
        query_embedding = self.model.encode([query])
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)

        # Search
        scores, indices = self.index.search(query_embedding.astype('float32'), min(k * 2, self.index.ntotal))

        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and idx < len(self.documents):
                metadata = self.metadata[idx]
                # Filter by content type if specified
                if content_type and metadata.get('content_type') != content_type:
                    continue
                result = {
                    'content': self.documents[idx].page_content,
                    'metadata': metadata,
                    'score': float(score)
                }
                results.append(result)
                if len(results) >= k:
                    break
        return results

    def get_context_for_content(self, content_id: str, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Get relevant context from other content for a specific item."""
        results = self.search(query, k=k)
        # Filter out results from the same content
        filtered_results = [r for r in results if r['metadata'].get('content_id') != content_id]
        return filtered_results[:k]

    def _save_index(self):
        """Save FAISS index and metadata to disk."""
        try:
            faiss.write_index(self.index, self.index_file)
            with open(self.metadata_file, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'metadata': self.metadata
                }, f)
        except Exception as e:
            print(f"Error saving index: {e}")

    def rebuild_index_from_projects(self):
        """Rebuild the entire vector index from current projects data."""
        from models import load_projects

        # Clear existing index
        self._create_empty_index()

        # Load all projects data
        data = load_projects()

        # Add stories
        for story_id, story in data.get("stories", {}).items():
            content = f"{story['title']}\n\n{story['description']}\n\n{story['content']}"
            self.add_content(content, "story", story_id, story['title'])

        # Add characters
        for char_id, char in data.get("characters", {}).items():
            content = f"{char['name']}\n\n{char['description']}\n\nTraits: {', '.join(char.get('traits', []))}\n\n{char.get('notes', '')}"
            self.add_content(content, "character", char_id, char['name'])

        # Add world elements
        for elem_id, elem in data.get("world_elements", {}).items():
            content = f"{elem['name']} ({elem['type']})\n\n{elem['description']}\n\nTags: {', '.join(elem.get('tags', []))}\n\n{elem.get('notes', '')}"
            self.add_content(content, "world_element", elem_id, elem['name'])

        # Add scripts
        for proj_id, proj in data.get("projects", {}).items():
            if proj.get('content'):
                content = f"{proj['name']}\n\n{proj['content']}\n\nNotes: {proj.get('notes', '')}"
                self.add_content(content, "script", proj_id, proj['name'])


# Global RAG service instance
rag_service = RAGService()
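

# Minimal usage sketch: running the module directly performs a quick smoke test.
# This is an illustrative example under the assumption that the FAISS and
# sentence-transformers dependencies are installed and that writing
# vector_index.faiss / vector_metadata.pkl into the working directory is
# acceptable; the demo text and IDs below are hypothetical.
if __name__ == "__main__":
    rag_service.add_content(
        content="Captain Mira charts a course through the shattered moons of Veyra.",
        content_type="story",
        content_id="demo-story-1",
        title="Demo Story",
    )
    for hit in rag_service.search("Who navigates the shattered moons?", k=2):
        print(f"{hit['score']:.3f} | {hit['metadata']['title']}: {hit['content'][:60]}")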