Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,10 @@ import soundfile as sf
|
|
| 11 |
from langchain.agents import AgentExecutor, create_react_agent
|
| 12 |
from langchain.tools import BaseTool
|
| 13 |
from langchain_groq import ChatGroq
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from PIL import Image
|
| 15 |
from tavily import TavilyClient
|
| 16 |
import requests
|
|
@@ -130,28 +134,50 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False, document=
|
|
| 130 |
user_prompt = transcription.text
|
| 131 |
|
| 132 |
tools = [
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
]
|
| 136 |
|
| 137 |
# Add the web search tool only if websearch mode is enabled
|
| 138 |
if websearch:
|
| 139 |
-
tools.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
# Add the document question answering tool only if a document is provided
|
| 142 |
if document:
|
| 143 |
-
tools.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
if image:
|
| 150 |
image = Image.open(image).convert('RGB')
|
| 151 |
messages = [{"role": "user", "content": [image, user_prompt]}]
|
| 152 |
response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
|
| 153 |
else:
|
| 154 |
-
response =
|
| 155 |
|
| 156 |
return response
|
| 157 |
|
|
|
|
| 11 |
from langchain.agents import AgentExecutor, create_react_agent
|
| 12 |
from langchain.tools import BaseTool
|
| 13 |
from langchain_groq import ChatGroq
|
| 14 |
+
from langchain.agents import AgentExecutor, initialize_agent, Tool
|
| 15 |
+
from langchain.agents import AgentType
|
| 16 |
+
from langchain_groq import ChatGroq
|
| 17 |
+
from langchain.prompts import PromptTemplate
|
| 18 |
from PIL import Image
|
| 19 |
from tavily import TavilyClient
|
| 20 |
import requests
|
|
|
|
| 134 |
user_prompt = transcription.text
|
| 135 |
|
| 136 |
tools = [
|
| 137 |
+
Tool(
|
| 138 |
+
name="Numpy",
|
| 139 |
+
func=NumpyCodeCalculator()._run,
|
| 140 |
+
description="Useful for performing numpy computations"
|
| 141 |
+
),
|
| 142 |
+
Tool(
|
| 143 |
+
name="Image",
|
| 144 |
+
func=ImageGeneration()._run,
|
| 145 |
+
description="Useful for generating images based on text descriptions"
|
| 146 |
+
),
|
| 147 |
]
|
| 148 |
|
| 149 |
# Add the web search tool only if websearch mode is enabled
|
| 150 |
if websearch:
|
| 151 |
+
tools.append(Tool(
|
| 152 |
+
name="Web",
|
| 153 |
+
func=WebSearch()._run,
|
| 154 |
+
description="Useful for searching the web for information"
|
| 155 |
+
))
|
| 156 |
|
| 157 |
# Add the document question answering tool only if a document is provided
|
| 158 |
if document:
|
| 159 |
+
tools.append(Tool(
|
| 160 |
+
name="Document",
|
| 161 |
+
func=DocumentQuestionAnswering(document)._run,
|
| 162 |
+
description="Useful for answering questions about a specific document"
|
| 163 |
+
))
|
| 164 |
|
| 165 |
llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
|
| 166 |
+
|
| 167 |
+
# Initialize the agent
|
| 168 |
+
agent = initialize_agent(
|
| 169 |
+
tools,
|
| 170 |
+
llm,
|
| 171 |
+
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
| 172 |
+
verbose=True
|
| 173 |
+
)
|
| 174 |
|
| 175 |
if image:
|
| 176 |
image = Image.open(image).convert('RGB')
|
| 177 |
messages = [{"role": "user", "content": [image, user_prompt]}]
|
| 178 |
response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
|
| 179 |
else:
|
| 180 |
+
response = agent.run(user_prompt)
|
| 181 |
|
| 182 |
return response
|
| 183 |
|