aamirhameed committed
Commit e1851fc · verified · 1 Parent(s): cd3eb05

Update knowledge_engine.py

Files changed (1): knowledge_engine.py +10 -7
knowledge_engine.py CHANGED
@@ -3,7 +3,6 @@ import pickle
 from typing import List, Dict, Any
 from datetime import datetime
 from concurrent.futures import ThreadPoolExecutor
-from sentence_transformers import SentenceTransformer
 
 from config import Config
 
@@ -23,7 +22,8 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.retrievers import BM25Retriever
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.llms import HuggingFaceEndpoint
+from transformers import AutoTokenizer, pipeline
+from langchain_community.llms import HuggingFacePipeline
 
 class KnowledgeManager:
     def __init__(self):
@@ -38,14 +38,17 @@ class KnowledgeManager:
 
     def _init_llm(self):
         print("[i] Using Hugging Face LLM")
-        return HuggingFaceEndpoint(
-            endpoint_url="https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct",
-            task="text-generation",
-            huggingfacehub_api_token=hf_token,
+        model_id = "tiiuae/falcon-7b-instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline(
+            "text-generation",
+            model=model_id,
+            tokenizer=tokenizer,
             temperature=0.1,
             max_new_tokens=512,
-            do_sample=True
+            device_map="auto"
         )
+        return HuggingFacePipeline(pipeline=pipe)
 
     def _init_retrievers(self):
         faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
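The net effect of the commit: _init_llm no longer calls the hosted Inference API through HuggingFaceEndpoint, and instead loads tiiuae/falcon-7b-instruct locally with a transformers text-generation pipeline wrapped in LangChain's HuggingFacePipeline (the unused sentence_transformers import is also dropped). Below is a minimal standalone sketch of that setup; the llm variable name, the example prompt, and the assumption that accelerate is installed for device_map="auto" are illustrative and not part of the commit.

from transformers import AutoTokenizer, pipeline
from langchain_community.llms import HuggingFacePipeline

# Load the model locally rather than calling the hosted Inference API.
model_id = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    temperature=0.1,
    max_new_tokens=512,
    device_map="auto",  # assumes accelerate is installed; spreads weights across available devices
)

# Wrapping the pipeline makes it usable wherever a LangChain LLM is expected,
# e.g. the RetrievalQA chain that knowledge_engine.py builds elsewhere.
llm = HuggingFacePipeline(pipeline=pipe)
print(llm.invoke("What does a FAISS index store?"))  # illustrative prompt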