Update app.py
app.py CHANGED

@@ -14,6 +14,11 @@ import sys
 from llama_cpp import Llama
 from tqdm import tqdm
 
+# At the top of your script
+os.environ['LLAMA_CPP_THREADS'] = '4'
+os.environ['LLAMA_CPP_BATCH_SIZE'] = '512'
+os.environ['LLAMA_CPP_MODEL_PATH'] = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+
 # Set page config first
 st.set_page_config(
     page_title="The Sport Chatbot",
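
The three environment variables added above duplicate values that the new get_llama_model() (next hunk) hardcodes again: four threads, batch size 512, and the same GGUF path. As far as this diff shows, nothing reads them back. A minimal sketch, assuming the variables are meant to drive the config rather than restate it, would build llm_config from the environment:

# Hypothetical sketch (not in this commit): read the settings back from the
# environment, falling back to the same defaults the code hardcodes.
import os

llm_config = {
    "model_path": os.environ.get(
        "LLAMA_CPP_MODEL_PATH",
        os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf"),
    ),
    "n_threads": int(os.environ.get("LLAMA_CPP_THREADS", "4")),
    "n_batch": int(os.environ.get("LLAMA_CPP_BATCH_SIZE", "512")),
}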
@@ -27,7 +32,28 @@ logging.basicConfig(
     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
     handlers=[logging.StreamHandler(sys.stdout)]
 )
+# Add this at the top level of your script, after imports
+@st.cache_resource
+def get_llama_model():
+    model_path = os.path.join("models", "mistral-7b-v0.1.Q4_K_M.gguf")
+    os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+    if not os.path.exists(model_path):
+        st.info("Downloading model... This may take a while.")
+        direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
+        download_file_with_progress(direct_url, model_path)
 
+    llm_config = {
+        "model_path": model_path,
+        "n_ctx": 2048,
+        "n_threads": 4,
+        "n_batch": 512,
+        "n_gpu_layers": 0,
+        "verbose": False,
+        "use_mlock": True
+    }
+
+    return Llama(**llm_config)
 def download_file_with_progress(url: str, filename: str):
     """Download a file with progress bar using requests"""
     response = requests.get(url, stream=True)
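
Only the signature, docstring, and first line of download_file_with_progress() show up as hunk context. Given that docstring and the tqdm import, the body presumably follows the standard streamed-download pattern; a minimal self-contained sketch of such a helper (assumed, not visible in this diff):

import requests
from tqdm import tqdm

def download_file_with_progress(url: str, filename: str):
    """Download a file with progress bar using requests"""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))

    # Stream in chunks so the multi-GB model file never sits in memory,
    # updating the progress bar as each chunk lands on disk.
    with open(filename, 'wb') as file, tqdm(
        desc=filename, total=total_size, unit='iB', unit_scale=True
    ) as progress_bar:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
            progress_bar.update(len(chunk))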
@@ -156,7 +182,8 @@ class RAGPipeline:
         self.retriever = SentenceTransformerRetriever()
         self.documents = []
         self.device = torch.device("cpu")
-
+        # Use the cached model directly
+        self.llm = get_llama_model()
 
     def preprocess_query(self, query: str) -> str:
         """Clean and prepare the query"""
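
The point of @st.cache_resource is that Streamlit re-executes the whole script on every interaction: without it, RAGPipeline.__init__ would reconstruct the Llama model (a multi-gigabyte load, plus a possible download) on each rerun. Cached, get_llama_model() runs once per server process, and every rerun and session shares the same instance. A hypothetical sketch of how the pipeline might then call self.llm (the actual call site is outside this diff):

# Hypothetical RAGPipeline method (not part of this commit). llama_cpp.Llama
# instances are callable and return an OpenAI-style completion dict.
def generate_answer(self, prompt: str) -> str:
    response = self.llm(
        prompt,
        max_tokens=512,       # assumed generation settings
        temperature=0.7,
        stop=["Question:"],   # assumed stop sequence
    )
    return response["choices"][0]["text"].strip()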