Spaces:
Sleeping
Sleeping
Jan Biermeyer
committed on
Commit
·
3379400
1
Parent(s):
34fc1eb
cpu
Browse files
- app.py +3 -3
- rag/rag.py +12 -12
app.py
CHANGED
|
@@ -17,7 +17,7 @@ import base64
|
|
| 17 |
# Add project root to path for imports
|
| 18 |
project_root = Path(__file__).parent
|
| 19 |
sys.path.insert(0, str(project_root))
|
| 20 |
-
from rag.
|
| 21 |
from rag.model_loader import load_enhanced_model_m2max, get_model_info
|
| 22 |
|
| 23 |
# Page configuration
|
|
@@ -645,7 +645,7 @@ def call_enhanced_model_with_rag(prompt: str) -> tuple[Optional[str], float]:
|
|
| 645 |
model, tokenizer = load_enhanced_model_m2max()
|
| 646 |
|
| 647 |
# Get RAG instance
|
| 648 |
-
rag =
|
| 649 |
|
| 650 |
# Generate response with RAG context
|
| 651 |
response = rag.generate_response(prompt, model, tokenizer)
|
|
@@ -752,7 +752,7 @@ def main():
|
|
| 752 |
|
| 753 |
# RAG Status
|
| 754 |
try:
|
| 755 |
-
rag =
|
| 756 |
rag_count = len(rag.collection.get()['ids'])
|
| 757 |
st.markdown(f"""
|
| 758 |
<div class="metric-card">
|
|
|
|
| 17 |
# Add project root to path for imports
|
| 18 |
project_root = Path(__file__).parent
|
| 19 |
sys.path.insert(0, str(project_root))
|
| 20 |
+
from rag.rag import get_supra_rag
|
| 21 |
from rag.model_loader import load_enhanced_model_m2max, get_model_info
|
| 22 |
|
| 23 |
# Page configuration
|
|
|
|
| 645 |
model, tokenizer = load_enhanced_model_m2max()
|
| 646 |
|
| 647 |
# Get RAG instance
|
| 648 |
+
rag = get_supra_rag()
|
| 649 |
|
| 650 |
# Generate response with RAG context
|
| 651 |
response = rag.generate_response(prompt, model, tokenizer)
|
|
|
|
| 752 |
|
| 753 |
# RAG Status
|
| 754 |
try:
|
| 755 |
+
rag = get_supra_rag()
|
| 756 |
rag_count = len(rag.collection.get()['ids'])
|
| 757 |
st.markdown(f"""
|
| 758 |
<div class="metric-card">
|
rag/rag.py
CHANGED
|
@@ -157,20 +157,20 @@ class SupraRAG:
|
|
| 157 |
st.warning("⚠️ No valid documents found in RAG data file")
|
| 158 |
|
| 159 |
def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
|
| 160 |
-
"""Retrieve relevant context for a query with
|
| 161 |
try:
|
| 162 |
-
# Limit query length for
|
| 163 |
if len(query) > 500:
|
| 164 |
query = query[:500]
|
| 165 |
|
| 166 |
results = self.collection.query(
|
| 167 |
query_texts=[query],
|
| 168 |
-
n_results=min(n_results, 5) # Limit results for
|
| 169 |
)
|
| 170 |
|
| 171 |
context_docs = []
|
| 172 |
for i, doc in enumerate(results['documents'][0]):
|
| 173 |
-
# Truncate retrieved content for
|
| 174 |
content = doc
|
| 175 |
if len(content) > 1500:
|
| 176 |
content = content[:1500] + "..."
|
|
@@ -191,15 +191,15 @@ class SupraRAG:
|
|
| 191 |
return []
|
| 192 |
|
| 193 |
def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
|
| 194 |
-
"""Build enhanced prompt with RAG context and SUPRA facts
|
| 195 |
# Import SUPRA facts system
|
| 196 |
from .supra_facts import build_supra_prompt, inject_facts_for_query
|
| 197 |
|
| 198 |
# Extract RAG context chunks
|
| 199 |
rag_context = None
|
| 200 |
if context_docs:
|
| 201 |
-
# Limit context length for
|
| 202 |
-
max_context_length = 2000 # Reduced for
|
| 203 |
context_text = ""
|
| 204 |
|
| 205 |
for doc in context_docs:
|
|
@@ -270,11 +270,11 @@ class SupraRAG:
|
|
| 270 |
|
| 271 |
# Global RAG instance with device-specific optimizations
|
| 272 |
@st.cache_resource
|
| 273 |
-
def
|
| 274 |
"""Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
|
| 275 |
-
return
|
| 276 |
|
| 277 |
-
# Backward compatibility
|
| 278 |
-
def
|
| 279 |
"""Backward compatible function that returns device-optimized RAG."""
|
| 280 |
-
return
|
|
|
|
| 157 |
st.warning("⚠️ No valid documents found in RAG data file")
|
| 158 |
|
| 159 |
def retrieve_context(self, query: str, n_results: int = 3) -> List[Dict[str, Any]]:
|
| 160 |
+
"""Retrieve relevant context for a query with device optimizations."""
|
| 161 |
try:
|
| 162 |
+
# Limit query length for efficiency
|
| 163 |
if len(query) > 500:
|
| 164 |
query = query[:500]
|
| 165 |
|
| 166 |
results = self.collection.query(
|
| 167 |
query_texts=[query],
|
| 168 |
+
n_results=min(n_results, 5) # Limit results for efficiency
|
| 169 |
)
|
| 170 |
|
| 171 |
context_docs = []
|
| 172 |
for i, doc in enumerate(results['documents'][0]):
|
| 173 |
+
# Truncate retrieved content for memory efficiency
|
| 174 |
content = doc
|
| 175 |
if len(content) > 1500:
|
| 176 |
content = content[:1500] + "..."
|
|
|
|
| 191 |
return []
|
| 192 |
|
| 193 |
def build_enhanced_prompt(self, user_query: str, context_docs: List[Dict[str, Any]]) -> str:
|
| 194 |
+
"""Build enhanced prompt with RAG context and SUPRA facts with device optimizations."""
|
| 195 |
# Import SUPRA facts system
|
| 196 |
from .supra_facts import build_supra_prompt, inject_facts_for_query
|
| 197 |
|
| 198 |
# Extract RAG context chunks
|
| 199 |
rag_context = None
|
| 200 |
if context_docs:
|
| 201 |
+
# Limit context length for memory efficiency
|
| 202 |
+
max_context_length = 2000 # Reduced for memory efficiency
|
| 203 |
context_text = ""
|
| 204 |
|
| 205 |
for doc in context_docs:
|
|
|
|
| 270 |
|
| 271 |
# Global RAG instance with device-specific optimizations
|
| 272 |
@st.cache_resource
|
| 273 |
+
def get_supra_rag():
|
| 274 |
"""Get cached SUPRA RAG instance optimized for CPU/MPS/CUDA."""
|
| 275 |
+
return SupraRAG()
|
| 276 |
|
| 277 |
+
# Backward compatibility (kept for compatibility with old imports)
|
| 278 |
+
def get_supra_rag_m2max():
|
| 279 |
"""Backward compatible function that returns device-optimized RAG."""
|
| 280 |
+
return get_supra_rag()
|