"""RAG (Retrieval Augmented Generation) services for ScriptVoice."""

import os
import json
import pickle
from typing import List, Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document


class RAGService:
    """Handles vector database operations and content retrieval."""
    
    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ". ", "! ", "? ", " "]
        )
        self.index = None
        self.documents = []
        self.metadata = []
        self.index_file = "vector_index.faiss"
        self.metadata_file = "vector_metadata.pkl"
        self._load_or_create_index()
    
    def _load_or_create_index(self):
        """Load existing index or create new one."""
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            try:
                self.index = faiss.read_index(self.index_file)
                with open(self.metadata_file, 'rb') as f:
                    data = pickle.load(f)
                    self.documents = data['documents']
                    self.metadata = data['metadata']
                print(f"Loaded vector index with {len(self.documents)} documents")
            except Exception as e:
                print(f"Error loading index: {e}")
                self._create_empty_index()
        else:
            self._create_empty_index()
    
    def _create_empty_index(self):
        """Create empty FAISS index."""
        dimension = 384  # embedding dimension of all-MiniLM-L6-v2
        # Inner-product index; with L2-normalized embeddings this is equivalent to cosine similarity
        self.index = faiss.IndexFlatIP(dimension)
        self.documents = []
        self.metadata = []
    
    def chunk_content(self, content: str, content_type: str, content_id: str, title: str) -> List[Document]:
        """Split content into chunks for embedding."""
        chunks = self.text_splitter.split_text(content)
        documents = []
        
        for i, chunk in enumerate(chunks):
            doc = Document(
                page_content=chunk,
                metadata={
                    'content_type': content_type,
                    'content_id': content_id,
                    'title': title,
                    'chunk_id': i,
                    'chunk_count': len(chunks)
                }
            )
            documents.append(doc)
        
        return documents
    
    def add_content(self, content: str, content_type: str, content_id: str, title: str):
        """Add content to the vector database."""
        if not content.strip():
            return
        
        # Remove existing content for this ID
        self.remove_content(content_id)
        
        # Chunk the content
        documents = self.chunk_content(content, content_type, content_id, title)
        
        if not documents:
            return
        
        # Generate embeddings
        texts = [doc.page_content for doc in documents]
        embeddings = self.model.encode(texts)
        
        # Normalize embeddings for cosine similarity
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        
        # Add to FAISS index
        self.index.add(embeddings.astype('float32'))
        
        # Store documents and metadata
        self.documents.extend(documents)
        for doc in documents:
            self.metadata.append(doc.metadata)
        
        # Save index
        self._save_index()
    
    def remove_content(self, content_id: str):
        """Remove content from vector database."""
        indices_to_remove = {
            i for i, metadata in enumerate(self.metadata)
            if metadata.get('content_id') == content_id
        }
        
        if indices_to_remove:
            # Keep only the chunks that belong to other content
            new_documents = []
            new_metadata = []
            for i, (doc, meta) in enumerate(zip(self.documents, self.metadata)):
                if i not in indices_to_remove:
                    new_documents.append(doc)
                    new_metadata.append(meta)
            
            # Recreate the index and re-embed the remaining chunks in a single batch
            self._create_empty_index()
            if new_documents:
                texts = [doc.page_content for doc in new_documents]
                embeddings = self.model.encode(texts)
                embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
                self.index.add(embeddings.astype('float32'))
                self.documents = new_documents
                self.metadata = new_metadata
            
            self._save_index()
    
    def search(self, query: str, k: int = 5, content_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Search for similar content."""
        if self.index.ntotal == 0:
            return []
        
        # Generate query embedding
        query_embedding = self.model.encode([query])
        query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
        
        # Search; over-fetch so that optional content-type filtering can still return up to k results
        scores, indices = self.index.search(query_embedding.astype('float32'), min(k * 2, self.index.ntotal))
        
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(self.documents):
                metadata = self.metadata[idx]
                
                # Filter by content type if specified
                if content_type and metadata.get('content_type') != content_type:
                    continue
                
                result = {
                    'content': self.documents[idx].page_content,
                    'metadata': metadata,
                    'score': float(score)
                }
                results.append(result)
                
                if len(results) >= k:
                    break
        
        return results
    
    def get_context_for_content(self, content_id: str, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Get relevant context from other content for a specific item."""
        # Over-fetch so that dropping chunks from the same item can still leave up to k results
        results = self.search(query, k=k * 2)
        # Filter out results from the same content
        filtered_results = [r for r in results if r['metadata'].get('content_id') != content_id]
        return filtered_results[:k]
    
    def _save_index(self):
        """Save FAISS index and metadata to disk."""
        try:
            faiss.write_index(self.index, self.index_file)
            with open(self.metadata_file, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'metadata': self.metadata
                }, f)
        except Exception as e:
            print(f"Error saving index: {e}")
    
    def rebuild_index_from_projects(self):
        """Rebuild the entire vector index from current projects data."""
        # Imported here rather than at module level to avoid a potential circular import
        from models import load_projects
        
        # Clear existing index
        self._create_empty_index()
        
        # Load all projects data
        data = load_projects()
        
        # Add stories
        for story_id, story in data.get("stories", {}).items():
            content = f"{story['title']}\n\n{story['description']}\n\n{story['content']}"
            self.add_content(content, "story", story_id, story['title'])
        
        # Add characters
        for char_id, char in data.get("characters", {}).items():
            content = f"{char['name']}\n\n{char['description']}\n\nTraits: {', '.join(char.get('traits', []))}\n\n{char.get('notes', '')}"
            self.add_content(content, "character", char_id, char['name'])
        
        # Add world elements
        for elem_id, elem in data.get("world_elements", {}).items():
            content = f"{elem['name']} ({elem['type']})\n\n{elem['description']}\n\nTags: {', '.join(elem.get('tags', []))}\n\n{elem.get('notes', '')}"
            self.add_content(content, "world_element", elem_id, elem['name'])
        
        # Add scripts
        for proj_id, proj in data.get("projects", {}).items():
            if proj.get('content'):
                content = f"{proj['name']}\n\n{proj['content']}\n\nNotes: {proj.get('notes', '')}"
                self.add_content(content, "script", proj_id, proj['name'])


# Global RAG service instance
rag_service = RAGService()
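

# Illustrative usage sketch (not part of the application flow). The story text,
# content_id, title, and query below are made-up demo values; real callers import
# `rag_service` from this module and pass their own project data. Note that this
# writes vector_index.faiss / vector_metadata.pkl to the working directory.
if __name__ == "__main__":
    rag_service.add_content(
        "Mira crossed the salt flats at dawn, carrying the last of the water.",
        content_type="story",
        content_id="demo_story",
        title="Salt Flats",
    )
    for hit in rag_service.search("who carried the water?", k=3):
        meta = hit['metadata']
        print(f"{hit['score']:.3f} [{meta['content_type']}] {meta['title']}: {hit['content'][:60]}")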