Jan Biermeyer committed on
Commit
3379400
·
1 Parent(s): 34fc1eb
Files changed (2) hide show
  1. app.py +3 -3
  2. rag/rag.py +12 -12
app.py CHANGED
@@ -17,7 +17,7 @@ import base64
17
  # Add project root to path for imports
18
  project_root = Path(__file__).parent
19
  sys.path.insert(0, str(project_root))
20
- from rag.rag_m2max import get_supra_rag_m2max
21
  from rag.model_loader import load_enhanced_model_m2max, get_model_info
22
 
23
  # Page configuration
@@ -645,7 +645,7 @@ def call_enhanced_model_with_rag(prompt: str) -> tuple[Optional[str], float]:
645
  model, tokenizer = load_enhanced_model_m2max()
646
 
647
  # Get RAG instance
648
- rag = get_supra_rag_m2max()
649
 
650
  # Generate response with RAG context
651
  response = rag.generate_response(prompt, model, tokenizer)
@@ -752,7 +752,7 @@ def main():
752
 
753
  # RAG Status
754
  try:
755
- rag = get_supra_rag_m2max()
756
  rag_count = len(rag.collection.get()['ids'])
757
  st.markdown(f"""
758
  <div class="metric-card">
 
17
  # Add project root to path for imports
18
  project_root = Path(__file__).parent
19
  sys.path.insert(0, str(project_root))
20
+ from rag.rag import get_supra_rag
21
  from rag.model_loader import load_enhanced_model_m2max, get_model_info
22
 
23
  # Page configuration
 
645
  model, tokenizer = load_enhanced_model_m2max()
646
 
647
  # Get RAG instance
648
+ rag = get_supra_rag()
649
 
650
  # Generate response with RAG context
651
  response = rag.generate_response(prompt, model, tokenizer)
 
752
 
753
  # RAG Status
754
  try:
755
+ rag = get_supra_rag()
756
  rag_count = len(rag.collection.get()['ids'])
757
  st.markdown(f"""
758
  <div class="metric-card">
rag/rag.py CHANGED
@@ -157,20 +157,20 @@ class SupraRAG:
157
  st.warning("⚠️ No valid documents found in RAG data file")
158
 
159
  def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
160
- """Retrieve relevant context for a query with M2 Max optimizations."""
161
  try:
162
- # Limit query length for M2 Max efficiency
163
  if len(query) > 500:
164
  query = query[:500]
165
 
166
  results = self.collection.query(
167
  query_texts=[query],
168
- n_results=min(n_results, 5) # Limit results for M2 Max
169
  )
170
 
171
  context_docs = []
172
  for i, doc in enumerate(results['documents'][0]):
173
- # Truncate retrieved content for M2 Max memory efficiency
174
  content = doc
175
  if len(content) > 1500:
176
  content = content[:1500] + "..."
@@ -191,15 +191,15 @@ class SupraRAG:
191
  return []
192
 
193
  def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
194
- """Build enhanced prompt with RAG context and SUPRA facts optimized for M2 Max."""
195
  # Import SUPRA facts system
196
  from .supra_facts import build_supra_prompt, inject_facts_for_query
197
 
198
  # Extract RAG context chunks
199
  rag_context = None
200
  if context_docs:
201
- # Limit context length for M2 Max memory efficiency
202
- max_context_length = 2000 # Reduced for M2 Max
203
  context_text = ""
204
 
205
  for doc in context_docs:
@@ -270,11 +270,11 @@ class SupraRAG:
270
 
271
  # Global RAG instance with device-specific optimizations
272
  @st.cache_resource
273
- def get_supra_rag_m2max():
274
  """Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
275
- return SupraRAGM2Max()
276
 
277
- # Backward compatibility
278
- def get_supra_rag():
279
  """Backward compatible function that returns device-optimized RAG."""
280
- return get_supra_rag_m2max()
 
157
  st.warning("⚠️ No valid documents found in RAG data file")
158
 
159
  def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
160
+ """Retrieve relevant context for a query with device optimizations."""
161
  try:
162
+ # Limit query length for efficiency
163
  if len(query) > 500:
164
  query = query[:500]
165
 
166
  results = self.collection.query(
167
  query_texts=[query],
168
+ n_results=min(n_results, 5) # Limit results for efficiency
169
  )
170
 
171
  context_docs = []
172
  for i, doc in enumerate(results['documents'][0]):
173
+ # Truncate retrieved content for memory efficiency
174
  content = doc
175
  if len(content) > 1500:
176
  content = content[:1500] + "..."
 
191
  return []
192
 
193
  def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
194
+ """Build enhanced prompt with RAG context and SUPRA facts with device optimizations."""
195
  # Import SUPRA facts system
196
  from .supra_facts import build_supra_prompt, inject_facts_for_query
197
 
198
  # Extract RAG context chunks
199
  rag_context = None
200
  if context_docs:
201
+ # Limit context length for memory efficiency
202
+ max_context_length = 2000 # Reduced for memory efficiency
203
  context_text = ""
204
 
205
  for doc in context_docs:
 
270
 
271
  # Global RAG instance with device-specific optimizations
272
  @st.cache_resource
273
+ def get_supra_rag():
274
  """Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
275
+ return SupraRAG()
276
 
277
+ # Backward compatibility (kept for compatibility with old imports)
278
+ def get_supra_rag_m2max():
279
  """Backward compatible function that returns device-optimized RAG."""
280
+ return get_supra_rag()