Update app.py
app.py CHANGED
@@ -44,6 +44,7 @@ tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
 
 # Function to play voice output
 def play_voice_output(response):
+    print("Executing play_voice_output function")
     description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
     input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
     prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
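(The two-step tokenization above — a voice description plus the response text — matches the Parler-TTS generation API. The checkpoint app.py actually loads is not visible in this diff, so the sketch below assumes the upstream parler_tts_mini checkpoint:

import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

# Assumed checkpoint; app.py's real TTS model is defined outside this hunk
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to("cuda")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

description = "Jon's voice is monotone yet slightly fast in delivery."
input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to("cuda")
prompt_input_ids = tts_tokenizer("Hello there!", return_tensors="pt").input_ids.to("cuda")

# The description conditions the voice; the prompt is the text to speak
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
sf.write("out.wav", generation.cpu().numpy().squeeze(), model.config.sampling_rate)

)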
@@ -58,6 +59,7 @@ class NumpyCodeCalculator(Tool):
     description = "Useful only for performing numerical computations, not for general searches"
 
     def _run(self, query: str) -> str:
+        print("Executing NumpyCodeCalculator tool")
         try:
             local_dict = {"np": np}
             exec(query, local_dict)
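(For context on the exec pattern this tool relies on: the generated code runs with np pre-bound, and any variable it assigns lands in local_dict. A minimal sketch — the result variable name is an assumption, since the hunk cuts off before the return value is extracted:

import numpy as np

# Code the agent might send; it assigns its answer to a variable
query = "result = np.mean(np.array([1, 2, 3, 4]))"

local_dict = {"np": np}
exec(query, local_dict)          # executes the string with np available
print(local_dict.get("result"))  # 2.5

)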
@@ -72,6 +74,7 @@ class WebSearch(Tool):
     description = "Useful for advanced web searching beyond general information"
 
     def _run(self, query: str) -> str:
+        print("Executing WebSearch tool")
         answer = tavily_client.qna_search(query=query)
         return answer
 
@@ -81,6 +84,7 @@ class ImageGeneration(Tool):
     description = "Useful for generating images based on text descriptions"
 
     def _run(self, query: str) -> str:
+        print("Executing ImageGeneration tool")
         image = pipe(
             query,
             negative_prompt="",
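(The hunk truncates mid-call. For context, a typical diffusers text-to-image invocation with a negative_prompt looks like the sketch below; the checkpoint and remaining keyword arguments are assumptions, as app.py's pipeline setup is not part of this diff:

import torch
from diffusers import StableDiffusionPipeline

# Assumed pipeline; app.py's actual `pipe` is defined outside this hunk
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a watercolor fox in a misty forest",  # illustrative prompt
    negative_prompt="",
    num_inference_steps=25,
).images[0]
image.save("output.png")

)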
@@ -101,6 +105,7 @@ class DocumentQuestionAnswering(Tool):
         self.qa_chain = self._setup_qa_chain()
 
     def _setup_qa_chain(self):
+        print("Setting up DocumentQuestionAnswering tool")
         loader = TextLoader(self.document)
         documents = loader.load()
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
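(This hunk also cuts off before the chain is assembled. In the legacy LangChain API that app.py's initialize_agent/AgentType imports imply, the loader/splitter lines above are usually completed roughly as follows; the embedding model, FAISS store, and Groq model name are assumptions:

from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_groq import ChatGroq

loader = TextLoader("notes.txt")  # hypothetical document path
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)

# Embed the chunks and serve them through a retriever-backed QA chain
db = FAISS.from_documents(texts, HuggingFaceEmbeddings())
llm = ChatGroq(model="llama3-8b-8192")  # assumed model; app.py uses MODEL
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=db.as_retriever())

)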
@@ -116,77 +121,73 @@ class DocumentQuestionAnswering(Tool):
         return qa_chain
 
     def _run(self, query: str) -> str:
+        print("Executing DocumentQuestionAnswering tool")
         response = self.qa_chain.run(query)
         return str(response)
 
 
 # Function to handle different input types and choose the right tool
 def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
+    print(f"Handling input: {user_prompt}")
 
-    tools = [
-        Tool(
-            name="Image",
-            func=ImageGeneration(),  # Pass the class instance, not ImageGeneration()._run
-            description="Useful for generating images based on text descriptions"
-        ),
-    ]
-
-    # Add the numpy tool, but with a more specific description
-    tools.append(Tool(
-        name="Calculator",
-        func=NumpyCodeCalculator(),  # Pass the class instance, not NumpyCodeCalculator()._run
-        description="Useful only for performing numerical computations, not for general searches"
-    ))
-
-    # Add the web search tool only if websearch mode is enabled
-    if websearch:
-        tools.append(Tool(
-            name="Web",
-            func=WebSearch(),  # Pass the class instance, not WebSearch()._run
-            description="Useful for advanced web searching beyond general information"
-        ))
+    # Initialize the LLM
+    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
 
-    #
-
-    tools.append(Tool(
-        name="Document",
-        func=DocumentQuestionAnswering(document),  # This is already correct
-        description="Useful for answering questions about a specific document"
-    ))
+    # Define the tools
+    tools = []
 
-
+    # Add Image Generation Tool
+    tools.append(ImageGeneration())
+
+    # Add Calculator Tool
+    tools.append(NumpyCodeCalculator())
 
-    #
-    requires_tool = False
-    for tool in tools:
-        if tool.name.lower() in user_prompt.lower():
-            requires_tool = True
-            break
+    # Add Web Search Tool if enabled
+    if websearch:
+        tools.append(WebSearch())
 
-
-    if requires_tool:
+    # Add Document QA Tool if document is provided
+    if document:
+        tools.append(DocumentQuestionAnswering(document))
+
+    # Check if any tools are mentioned in the user prompt
+    requires_tool = any([tool.name.lower() in user_prompt.lower() for tool in tools])
+
+    # Handle different input scenarios
+    if image:
+        print("Processing image input")
+        image = Image.open(image).convert('RGB')
+        messages = [{"role": "user", "content": [image, user_prompt]}]
+        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
+    elif audio:
+        print("Processing audio input")
+        transcription = client.audio.transcriptions.create(
+            file=(audio.name, audio.read()),
+            model="whisper-large-v3"
+        )
+        user_prompt = transcription.text
+        # If tools are required, use an agent
+        if requires_tool:
+            agent = initialize_agent(
+                tools,
+                llm,
+                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                verbose=True
+            )
+            response = agent.run(user_prompt)
+        else:
+            response = llm.call(query=user_prompt)
+    elif requires_tool:
+        print("Using agent with tools")
         agent = initialize_agent(
             tools,
             llm,
             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
             verbose=True
         )
-
-        if image:
-            image = Image.open(image).convert('RGB')
-            messages = [{"role": "user", "content": [image, user_prompt]}]
-            response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
-        elif audio:
-            transcription = client.audio.transcriptions.create(
-                file=(audio.name, audio.read()),
-                model="whisper-large-v3"
-            )
-            user_prompt = transcription.text
-            response = agent.run(user_prompt)
-        else:
-            response = agent.run(user_prompt)
+        response = agent.run(user_prompt)
     else:
-
+        print("Using LLM directly")
         response = llm.call(query=user_prompt)
 
     return response
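(The rewritten routing hinges on requires_tool, a plain substring match of each tool's name attribute against the prompt. A minimal self-contained sketch of that behaviour — the names mirror the tool classes, which is an assumption, since the name attributes are declared outside this diff:

class FakeTool:
    def __init__(self, name):
        self.name = name

tools = [FakeTool("Image"), FakeTool("Calculator"), FakeTool("Web")]

prompt = "Use the calculator to add 2 and 2"
requires_tool = any(tool.name.lower() in prompt.lower() for tool in tools)
print(requires_tool)  # True: "calculator" occurs in the prompt

Note the match is purely lexical: a prompt like "compute 2 + 2" would bypass the Calculator tool entirely.)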
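(Finally, a sketch of how the rewritten handle_input entry point might be exercised; the prompts and document path are illustrative, not from the source:

# No tool name in the prompt: falls through to llm.call directly
print(handle_input("Summarize the plot of Hamlet"))

# Mentions "web" and enables websearch: routed through the agent
print(handle_input("Search the web for today's weather in Paris", websearch=True))

# Document question answering over a provided file
print(handle_input("What does the introduction claim?", document="notes.txt"))

)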