|
|
import os
import tempfile
import traceback

import gradio as gr
import groq
import numpy as np
import requests
import soundfile as sf
import spaces
import torch
from diffusers import StableDiffusion3Pipeline
from groq import Groq
from huggingface_hub import hf_hub_download
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.agent import ReActAgent
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.tools import FunctionTool
# NOTE: this rebinds `Groq` to llama_index's LLM wrapper, shadowing the SDK's
# `Groq` imported above. It must stay after `from groq import Groq`; use
# `groq.Groq` when the raw SDK client is needed.
from llama_index.llms.groq import Groq
from parler_tts import ParlerTTSForConditionalGeneration
from PIL import Image
from safetensors.torch import load_file
from tavily import TavilyClient
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
|
|
|
# Groq-hosted chat model used by the ReAct agent in handle_input.
MODEL = 'llama3-groq-70b-8192-tool-use-preview'

# Raw Groq SDK client, used by handle_input for Whisper transcription through
# `client.audio.transcriptions`. It must be built from `groq.Groq`: the bare
# name `Groq` is rebound by `from llama_index.llms.groq import Groq`, and that
# llama_index LLM wrapper is not the SDK client (it does not provide `.audio`).
client = groq.Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
|
|
|
|
# Vision-language model (MiniCPM-V-2) answering questions about uploaded
# images in handle_input, plus its tokenizer. The repo ships custom modeling
# code, hence trust_remote_code.
_VQA_REPO = 'openbmb/MiniCPM-V-2'

vqa_model = AutoModel.from_pretrained(
    _VQA_REPO,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(_VQA_REPO, trust_remote_code=True)
|
|
|
|
|
# Parler-TTS text-to-speech model and tokenizer used by play_voice_output.
# Loaded with default placement; main_interface moves the model to CUDA.
_TTS_REPO = "parler-tts/parler-tts-large-v1"

tts_model = ParlerTTSForConditionalGeneration.from_pretrained(_TTS_REPO)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_REPO)
|
|
|
|
|
|
|
|
# Stable Diffusion 3 (medium) text-to-image pipeline in fp16, placed on CUDA.
# Used by the image_generation agent tool.
_SD3_REPO = "stabilityai/stable-diffusion-3-medium-diffusers"

pipe = StableDiffusion3Pipeline.from_pretrained(
    _SD3_REPO,
    torch_dtype=torch.float16,
).to("cuda")
|
|
|
|
|
|
|
|
# Tavily client backing the optional "Web" agent tool (web_search). The key is
# read from the TAVILY_API environment variable (None if unset).
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API"))
|
|
|
|
|
|
|
|
def play_voice_output(response):
    """Synthesize *response* as speech with Parler-TTS and save it as a WAV file.

    Nothing is played here. The caller (the Gradio UI) receives the file path
    and plays it.

    Args:
        response: Text to speak.

    Returns:
        Path of a newly created ``.wav`` file. Each call gets a unique temporary
        file, so concurrent requests cannot overwrite each other's audio
        (a fixed ``output.wav`` was shared across all users).
    """
    # Parler-TTS conditions the voice on a free-text speaker description.
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

    # Follow the model's current device rather than assuming CUDA.
    device = tts_model.device
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to(device)

    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()

    # delete=False: the file must outlive this function so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, audio_arr, tts_model.config.sampling_rate)
    return out_path
|
|
|
|
|
|
|
|
def numpy_code_calculator(query):
    """Run a short Python/NumPy snippet (``np`` is available) and return the result.

    Pass either a single expression, e.g. ``np.mean([1, 2, 3])``, whose value
    is returned, or statements that assign the answer to a variable named
    ``result``, e.g. ``a = np.arange(5); result = a.sum()``.
    """
    # This docstring doubles as the agent-facing tool description
    # (FunctionTool.from_defaults reads fn.__doc__), so it is written for the LLM.
    #
    # Returns str(value), "No result found" when a statement snippet never sets
    # `result`, or "Error: <message>" when the snippet fails to compile or raises.
    #
    # SECURITY: this executes arbitrary code chosen by the LLM, and therefore
    # indirectly by the end user. It is NOT sandboxed; isolate or restrict it
    # before exposing the app publicly.
    namespace = {"np": np}
    try:
        try:
            # Single expressions are evaluated directly, so the agent does not
            # have to remember the `result =` convention.
            code = compile(query, "<numpy_code_calculator>", "eval")
        except SyntaxError:
            code = None
        if code is not None:
            return str(eval(code, namespace))
        exec(compile(query, "<numpy_code_calculator>", "exec"), namespace)
        return str(namespace.get("result", "No result found"))
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
|
|
|
|
def web_search(query):
    """Search the web for up-to-date information and return a short answer to *query*."""
    # The docstring above is the agent-facing tool description. Without it,
    # FunctionTool.from_defaults describes the tool only by its signature.
    # Delegates to Tavily's question-answering search via the module-level client.
    answer = tavily_client.qna_search(query=query)
    return answer
|
|
|
|
|
|
|
|
def image_generation(query):
    """Generate an image from a text prompt and return the path of the saved JPEG."""
    # The docstring above is the agent-facing tool description.
    # Uses the module-level Stable Diffusion 3 pipeline with a short schedule
    # (15 steps) and an empty negative prompt.
    image = pipe(
        query,
        negative_prompt="",
        num_inference_steps=15,
        guidance_scale=7.0,
    ).images[0]

    # A unique file per call: a fixed "output.jpg" was overwritten by
    # concurrent requests. delete=False keeps the file after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        out_path = tmp.name
    image.save(out_path)
    # NOTE(review): the path is returned to the agent as text only; the UI
    # does not display the image.
    return out_path
|
|
|
|
|
|
|
|
def document_question_answering(query, docs):
    """Answer *query* by retrieval over *docs*.

    A fresh ``VectorStoreIndex`` is built from ``docs`` on every call. It is
    then queried using the 3 most similar chunks as context.

    NOTE(review): no embedding model or LLM is passed here, so llama_index's
    global ``Settings`` decide both. Confirm they are configured for this app;
    the library defaults are OpenAI-backed.

    Args:
        query: Natural-language question.
        docs: Documents accepted by ``VectorStoreIndex.from_documents``
            (presumably llama_index ``Document`` objects; confirm with caller).

    Returns:
        The query engine's answer converted to ``str``.
    """
    engine = VectorStoreIndex.from_documents(docs).as_query_engine(similarity_top_k=3)
    answer = engine.query(query)
    return str(answer)
|
|
|
|
|
|
|
|
def _transcribe_audio(audio):
    """Transcribe an audio file (path or binary file object) with Groq Whisper."""
    if isinstance(audio, str):
        # Context manager so the handle is closed; it was previously leaked.
        with open(audio, "rb") as fh:
            payload = (fh.name, fh.read())
    else:
        payload = (audio.name, audio.read())
    transcription = client.audio.transcriptions.create(
        file=payload,
        model="whisper-large-v3",
    )
    return transcription.text


def _build_tools(websearch, document):
    """Assemble the FunctionTools offered to the ReAct agent for one request."""
    tools = [
        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
        FunctionTool.from_defaults(fn=image_generation, name="Image"),
    ]

    if websearch:
        tools.append(FunctionTool.from_defaults(fn=web_search, name="Web"))

    if document:
        # gr.File(type="filepath") yields a path string; also accept a
        # file-like object exposing `.name`.
        path = getattr(document, "name", document)
        # The original called LlamaParse, which this file never imports
        # (NameError). Use llama_index's built-in local reader instead.
        docs = SimpleDirectoryReader(input_files=[path]).load_data()

        def document_qa(query: str) -> str:
            """Answer a question about the document the user uploaded."""
            return document_question_answering(query, docs)

        # FunctionTool.from_defaults has no `docs` parameter, so the parsed
        # documents are bound through the closure instead.
        tools.append(FunctionTool.from_defaults(fn=document_qa, name="Document"))

    return tools


def _answer_about_image(image, user_prompt):
    """Ask MiniCPM-V-2 *user_prompt* about the image stored at path *image*."""
    pil_image = Image.open(image).convert('RGB')
    # MiniCPM-V-2 takes the image via `image=` and plain-text message content
    # (the list-with-image content format belongs to later MiniCPM versions).
    msgs = [{"role": "user", "content": user_prompt}]
    result = vqa_model.chat(image=pil_image, msgs=msgs, context=None, tokenizer=tokenizer)
    # V-2's chat() returns (answer, context, ...); keep only the answer text.
    return result[0] if isinstance(result, tuple) else result


def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
    """Produce a text answer for one UI request.

    If ``audio`` is given, its transcription replaces ``user_prompt``. With an
    ``image``, the question is answered by the vision model directly. Otherwise
    a ReAct agent is run with the Numpy and Image tools, plus Web when
    ``websearch`` is set and Document when a ``document`` is uploaded.

    Args:
        user_prompt: The user's text message.
        image: Optional path to an uploaded image.
        audio: Optional audio file path or binary file object.
        websearch: Whether to give the agent the web-search tool.
        document: Optional uploaded document (path) to answer questions about.

    Returns:
        The response text.
    """
    if audio:
        user_prompt = _transcribe_audio(audio)

    # The image path never uses the agent, so don't build tools/agent for it.
    if image:
        return _answer_about_image(image, user_prompt)

    llm = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
    agent = ReActAgent.from_tools(_build_tools(websearch, document), llm=llm, verbose=True)
    response = agent.chat(user_prompt)

    if isinstance(response, AgentChatResponse):
        response = response.response
    return response
|
|
|
|
|
|
|
|
|
|
|
_APP_CSS = """
/* Overall Styling */
body {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #f4f4f4;
    margin: 0;
    padding: 0;
    color: #333;
}

/* Title Styling */
.gradio-container h1 {
    text-align: center;
    padding: 20px 0;
    background-color: #007bff; /* Example color */
    color: white;
}

/* Input Area Styling */
.gradio-container .gr-row {
    display: flex;
    justify-content: space-around;
    align-items: center;
    padding: 20px;
}

.gradio-container .gr-column {
    flex: 1;
    margin: 0 10px;
}

/* Textbox Styling */
.gradio-container textarea {
    width: calc(100% - 20px);
    padding: 10px;
    border: 2px solid #ccc;
    border-radius: 5px;
    font-size: 16px;
}

/* Button Styling */
.gradio-container button {
    background-color: #007bff; /* Example color */
    color: white;
    padding: 12px 20px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
    transition: background-color 0.3s;
}

.gradio-container button:hover {
    background-color: #0056b3; /* Example darker color */
}

/* Output Area Styling */
.gradio-container .output-area {
    padding: 20px;
    text-align: center;
}

/* Image Styling */
.gradio-container img {
    max-width: 100%;
    height: auto;
    border-radius: 5px;
}
"""


def create_ui():
    """Build the Gradio Blocks UI and wire its Submit button to main_interface.

    Voice-only mode hides the text/image/web/document inputs and shows the
    audio input and the audio output. Submit stays visible in every mode so a
    request can always be sent.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css=_APP_CSS) as demo:
        gr.Markdown("# AI Assistant")

        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                document_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        # Responses are free-form LLM text, so use a read-only Textbox rather
        # than gr.Label, which is meant for classification labels.
        output_label = gr.Textbox(label="Output", interactive=False)
        # Shown only in voice-only mode (see _toggle_voice_only). Previously it
        # was never made visible, so spoken output could not be heard.
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input],
            outputs=[output_label, audio_output],
        )

        hidden_in_voice_mode = [user_prompt, image_input, websearch_mode, document_input]
        shown_in_voice_mode = [audio_input, audio_output]

        def _toggle_voice_only(enabled):
            """Return visibility updates for voice-only mode being *enabled*."""
            return (
                [gr.update(visible=not enabled) for _ in hidden_in_voice_mode]
                + [gr.update(visible=enabled) for _ in shown_in_voice_mode]
            )

        # Submit is deliberately not hidden: nothing else triggers main_interface.
        voice_only_mode.change(
            _toggle_voice_only,
            inputs=voice_only_mode,
            outputs=hidden_in_voice_mode + shown_in_voice_mode,
        )

    return demo
|
|
|
|
|
|
|
|
@spaces.GPU()
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
    """Gradio callback: answer one request and optionally speak the answer.

    Args:
        user_prompt: Text from the prompt box.
        image: Optional uploaded image path.
        audio: Optional uploaded/recorded audio path.
        voice_only: When true, synthesize the answer to audio.
        websearch: Enable the agent's web-search tool.
        document: Optional uploaded document path.

    Returns:
        ``(text, audio_path)`` for the Output and Audio Output components.
        ``audio_path`` is ``None`` unless voice-only mode produced audio.
        Errors are reported as short messages in ``text`` rather than raised.
    """
    print("Starting main_interface function")
    # Place the models on the GPU inside the @spaces.GPU-decorated call
    # (presumably because the Space only has a GPU attached here; confirm).
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    pipe.to("cuda")

    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")

    try:
        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
        print("handle_input function executed successfully")
    except Exception as e:
        print(f"Error in handle_input: {e}")
        # Keep the full traceback in the logs; the user only sees a short message.
        traceback.print_exc()
        response = "Error occurred during processing."

    # The output is a text component; coerce non-string results defensively.
    response = "" if response is None else str(response)

    if not voice_only:
        return response, None

    try:
        audio_output = play_voice_output(response)
    except Exception as e:
        print(f"Error in play_voice_output: {e}")
        traceback.print_exc()
        return "Error occurred during voice output.", None

    print("play_voice_output function executed successfully")
    return "Response generated.", audio_output
|
|
|
|
|
|
|
|
|
|
|
# `demo` stays module-level so tools that import this module can find the app.
# The server only starts when the file is run as a script, not on import.
demo = create_ui()

if __name__ == "__main__":
    demo.launch()