KaiquanMah commited on
Commit
18fc979
Β·
verified Β·
1 Parent(s): 8aafb93

Retriever - LangGraph version

Browse files
Files changed (1) hide show
  1. retriever.py +24 -22
retriever.py CHANGED
@@ -1,33 +1,35 @@
1
- from smolagents import Tool
 
2
  from langchain_community.retrievers import BM25Retriever
3
- from langchain.docstore.document import Document
 
4
  import datasets
 
 
5
 
 
 
 
 
6
 
7
- class GuestInfoRetrieverTool(Tool):
8
- name = "guest_info_retriever"
9
- description = "Retrieves detailed information about gala guests based on their name or relation."
10
- inputs = {
11
- "query": {
12
- "type": "string",
13
- "description": "The name or relation of the guest you want information about."
14
- }
15
- }
16
- output_type = "string"
17
 
18
- def __init__(self, docs):
19
- self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
 
 
22
 
23
- def forward(self, query: str):
24
- results = self.retriever.get_relevant_documents(query)
25
- if results:
26
- return "\n\n".join([doc.page_content for doc in results[:3]])
27
- else:
28
- return "No matching guest information found."
29
 
30
 
 
31
  def load_guest_dataset():
32
  # Load the dataset
33
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
 
1
+ # retriever tool
2
+ from langchain.tools import Tool
3
  from langchain_community.retrievers import BM25Retriever
4
+
5
+ # load dataset
6
  import datasets
7
+ from langchain.docstore.document import Document
8
+
9
 
10
+ ###########################
11
+ # Retriever - changed to LangGraph version
12
+ ###########################
13
+ bm25_retriever = BM25Retriever.from_documents(docs)
14
 
15
+ def extract_text(query: str) -> str:
16
+ """Retrieves detailed information about gala guests based on their name or relation."""
17
+ results = bm25_retriever.invoke(query)
18
+ if results:
19
+ return "\n\n".join([doc.page_content for doc in results[:3]])
20
+ else:
21
+ return "No matching guest information found."
 
 
 
22
 
23
+ guest_info_tool = Tool(
24
+ name="guest_info_retriever",
25
+ func=extract_text,
26
+ description="Retrieves detailed information about gala guests based on their name or relation."
27
+ )
28
+ ###########################
29
 
 
 
 
 
 
 
30
 
31
 
32
+ # no change from smolagents to LangGraph
33
  def load_guest_dataset():
34
  # Load the dataset
35
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")