Spaces:
Sleeping
Sleeping
Update knowledge_engine.py
Browse files- knowledge_engine.py +10 -7
knowledge_engine.py
CHANGED
|
@@ -3,7 +3,6 @@ import pickle
|
|
| 3 |
from typing import List, Dict, Any
|
| 4 |
from datetime import datetime
|
| 5 |
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
-
from sentence_transformers import SentenceTransformer
|
| 7 |
|
| 8 |
from config import Config
|
| 9 |
|
|
@@ -23,7 +22,8 @@ from langchain.chains import RetrievalQA
|
|
| 23 |
from langchain.prompts import PromptTemplate
|
| 24 |
from langchain.retrievers import BM25Retriever
|
| 25 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 26 |
-
from
|
|
|
|
| 27 |
|
| 28 |
class KnowledgeManager:
|
| 29 |
def __init__(self):
|
|
@@ -38,14 +38,17 @@ class KnowledgeManager:
|
|
| 38 |
|
| 39 |
def _init_llm(self):
|
| 40 |
print("[i] Using Hugging Face LLM")
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
| 45 |
temperature=0.1,
|
| 46 |
max_new_tokens=512,
|
| 47 |
-
|
| 48 |
)
|
|
|
|
| 49 |
|
| 50 |
def _init_retrievers(self):
|
| 51 |
faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
|
|
|
|
| 3 |
from typing import List, Dict, Any
|
| 4 |
from datetime import datetime
|
| 5 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 6 |
|
| 7 |
from config import Config
|
| 8 |
|
|
|
|
| 22 |
from langchain.prompts import PromptTemplate
|
| 23 |
from langchain.retrievers import BM25Retriever
|
| 24 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 25 |
+
from transformers import AutoTokenizer, pipeline
|
| 26 |
+
from langchain_community.llms import HuggingFacePipeline
|
| 27 |
|
| 28 |
class KnowledgeManager:
|
| 29 |
def __init__(self):
|
|
|
|
| 38 |
|
| 39 |
def _init_llm(self):
|
| 40 |
print("[i] Using Hugging Face LLM")
|
| 41 |
+
model_id = "tiiuae/falcon-7b-instruct"
|
| 42 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 43 |
+
pipe = pipeline(
|
| 44 |
+
"text-generation",
|
| 45 |
+
model=model_id,
|
| 46 |
+
tokenizer=tokenizer,
|
| 47 |
temperature=0.1,
|
| 48 |
max_new_tokens=512,
|
| 49 |
+
device_map="auto"
|
| 50 |
)
|
| 51 |
+
return HuggingFacePipeline(pipeline=pipe)
|
| 52 |
|
| 53 |
def _init_retrievers(self):
|
| 54 |
faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
|