from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import AzureChatOpenAI


import os
import openai
os.environ['CWD'] = os.getcwd()

# for testing
import src.constants as constants
# import constants 
os.environ['OPENAI_API_KEY'] = constants.AZURE_OPENAI_KEY_FR
os.environ['OPENAI_API_BASE'] = constants.AZURE_OPENAI_ENDPOINT_FR
os.environ['OPENAI_API_VERSION'] = "2023-05-15"
os.environ['OPENAI_API_TYPE'] = "azure"
# openai.api_type = "azure"
# openai.api_base = constants.AZURE_OPENAI_ENDPOINT_FR
# openai.api_version = "2023-05-15"
openai.api_key = constants.OPEN_AI_KEY

def get_document_key(doc):
    """Build a unique key for a chunk from its source file path and page number."""
    return f"{doc.metadata['source']}_page_{doc.metadata['page']}"


from typing import Optional

class PDFEmbeddings:
    """Embed PDFs from a directory into a Chroma vector store and answer queries over them."""

    def __init__(self, path: Optional[str] = None):
        self.path = path or os.path.join(os.environ['CWD'], 'archive')
        self.text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings(deployment=constants.AZURE_ENGINE_NAME_US, chunk_size=1,
                                           openai_api_key=constants.AZURE_OPENAI_KEY_US,
                                           openai_api_base=constants.AZURE_OPENAI_ENDPOINT_US,
                                           openai_api_version="2023-05-15",
                                           openai_api_type="azure")
        self.vectorstore = Chroma(persist_directory=constants.persistent_dir, embedding_function=self.embeddings)
        self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        # ConversationalRetrievalChain looks up the history under the 'chat_history' key by default
        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)

    def process_documents(self):
        # Load the documents and process them
        loader = PyPDFDirectoryLoader(self.path)
        documents = loader.load()
        chunks = self.text_splitter.split_documents(documents)
        self.vectorstore.add_documents(chunks)

    def search(self, query: str, chain_type: str = "stuff"):
        """Answer a one-off query, returning the answer together with its source documents."""
        chain = RetrievalQA.from_chain_type(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR, temperature=0),
                                            retriever=self.retriever, chain_type=chain_type, return_source_documents=True)
        result = chain({"query": query})
        return result

    def conversational_search(self, query: str, chain_type: str = "stuff"):
        """Answer a query while keeping the conversation history in the buffer memory."""
        chain = ConversationalRetrievalChain.from_llm(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR),
                                                      retriever=self.retriever, memory=self.memory, chain_type=chain_type)
        result = chain({"question": query})
        return result['answer']

    def load_and_run_chain(self, query: str, chain_type: str = "stuff"):
        # load_qa_chain expects concrete documents rather than a retriever, so fetch them first
        docs = self.retriever.get_relevant_documents(query)
        chain = load_qa_chain(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR), chain_type=chain_type)
        return chain.run(input_documents=docs, question=query)

if __name__ == '__main__':
    pdf_embed = PDFEmbeddings()
    # pdf_embed.process_documents() # This takes a while, so we only do it once
    result = pdf_embed.search("Give me a list of short relevant queries to look for papers related to the topics of the papers in the source documents.")
    print("\n\n", result['result'], "\n")
    print("Source documents:")
    for doc in result['source_documents']:
        print(doc.metadata['source'])