|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import numpy as np |
|
|
from groq import Groq |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler |
|
|
from parler_tts import ParlerTTSForConditionalGeneration |
|
|
import soundfile as sf |
|
|
from langchain_community.embeddings import OpenAIEmbeddings |
|
|
from langchain_community.vectorstores import Chroma |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.chains import RetrievalQA |
|
|
from PIL import Image |
|
|
from decord import VideoReader, cpu |
|
|
from tavily import TavilyClient |
|
|
import requests |
|
|
from huggingface_hub import hf_hub_download |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
# Groq API client, used below for chat completions and Whisper transcription.
# The API key is taken from the GROQ_API_KEY environment variable.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Groq-hosted chat model used for the plain-conversation fallback.
MODEL = "llama3-groq-70b-8192-tool-use-preview"
|
|
|
|
|
# Vision-language chat model (MiniCPM-V 2) and its tokenizer. Remote code is
# trusted because the model ships its own modeling/chat implementation.
_VLM_REPO_ID = "openbmb/MiniCPM-V-2"
text_model = AutoModel.from_pretrained(
    _VLM_REPO_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(_VLM_REPO_ID, trust_remote_code=True)

# Parler-TTS text-to-speech model and tokenizer. No device_map is given here,
# so the model is not placed on the GPU at load time.
_TTS_REPO_ID = "parler-tts/parler-tts-large-v1"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(_TTS_REPO_ID)
tts_tokenizer = AutoTokenizer.from_pretrained(_TTS_REPO_ID)
|
|
|
|
|
|
|
|
# Text-to-image: SDXL base pipeline with its UNet replaced by the
# SDXL-Lightning 4-step distilled UNet weights.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"

# Build an empty UNet from the SDXL base config, move it to the GPU in fp16,
# then load the Lightning checkpoint weights into it.
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet = unet.to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

# Assemble the full pipeline around that UNet.
image_pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Swap in an Euler scheduler that uses "trailing" timestep spacing.
image_pipe.scheduler = EulerDiscreteScheduler.from_config(
    image_pipe.scheduler.config,
    timestep_spacing="trailing",
)
|
|
|
|
|
|
|
|
# Tavily web-search client. The key is read from the TAVILY_API_KEY environment
# variable, matching how the Groq client is configured, instead of shipping a
# hard-coded placeholder in source. The original placeholder is kept only as the
# fallback, so behavior is unchanged when the variable is unset.
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY", "tvly-YOUR_API_KEY"))
|
|
|
|
|
|
|
|
# Default Parler-TTS speaker description used when the caller does not supply one.
_DEFAULT_VOICE_DESCRIPTION = (
    "Jon's voice is monotone yet slightly fast in delivery, with a very close "
    "recording that almost has no background noise."
)


def play_voice_output(response, description=_DEFAULT_VOICE_DESCRIPTION, output_path="output.wav"):
    """Synthesize *response* as speech with Parler-TTS and write it to a WAV file.

    Args:
        response: Text to speak. Non-string values (e.g. a non-text model
            response) are converted with ``str()`` before tokenization.
        description: Natural-language description of the voice to render.
        output_path: Destination path of the WAV file.

    Returns:
        ``output_path``, the path of the written audio file.
    """
    # Send the inputs to the device the model currently lives on rather than a
    # hard-coded "cuda": the model is loaded without a device map, so it can
    # still be on the CPU when this is called.
    device = tts_model.device
    text = response if isinstance(response, str) else str(response)

    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)

    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write(output_path, audio_arr, tts_model.config.sampling_rate)
    return output_path
|
|
|
|
|
|
|
|
def numpy_calculate(code: str) -> str:
    """Execute a Python snippet and return the string form of its ``result``.

    The snippet runs with ``np`` (NumPy) pre-bound and must assign its answer
    to a variable named ``result``.

    Args:
        code: Python source to execute.

    Returns:
        ``str(result)`` on success; ``"No result found"`` if the snippet never
        assigns ``result``; ``"An error occurred: <message>"`` if compiling or
        running the snippet raises.

    WARNING: this executes arbitrary Python with full builtins. Passing it
    untrusted text (such as raw input from a public web UI) allows remote code
    execution on the host.
    """
    # Use ONE dict as both globals and locals. With separate dicts, exec()
    # applies class-body scoping, so functions and comprehensions inside the
    # snippet cannot see the snippet's own top-level names. For example,
    # ``n = 3; result = [x * n for x in range(3)]`` raised NameError.
    namespace = {"np": np}
    try:
        exec(code, namespace)  # noqa: S102 - see WARNING in the docstring
        return str(namespace.get("result", "No result found"))
    except Exception as e:
        return f"An error occurred: {e}"
|
|
|
|
|
|
|
|
def use_langchain_rag(file_name, file_content, query):
    """Answer *query* from *file_content* using a throwaway retrieval-QA chain.

    The text is split into 1000-character chunks with no overlap and embedded
    with ``OpenAIEmbeddings`` into a fresh Chroma collection. That collection
    is queried through a ``RetrievalQA`` "stuff" chain backed by an OpenAI LLM,
    then deleted.

    Args:
        file_name: Name of the uploaded file. Unused; kept for API compatibility.
        file_content: Decoded text of the document.
        query: Question to answer.

    Returns:
        The chain's answer, or a short message when the document has no text.
    """
    import uuid

    # ``OpenAI`` was referenced here but never imported anywhere in the module.
    from langchain_community.llms import OpenAI

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.create_documents([file_content])
    if not docs:
        # Nothing to index; building a vector store from zero documents fails.
        return "The uploaded document contains no text."

    # Use a unique, non-persisted collection per call. Persisting every upload
    # into the same ".chroma_db" collection made later queries retrieve chunks
    # from earlier (possibly other users') documents.
    db = Chroma.from_documents(
        docs,
        OpenAIEmbeddings(),
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
    try:
        qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=db.as_retriever())
        return qa.run(query)
    finally:
        db.delete_collection()
|
|
|
|
|
|
|
|
def _uniform_sample(items, n):
    """Return ``n`` evenly spaced elements of *items*, each centred in its stride."""
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]


def encode_video(video_path, max_num_frames=64):
    """Decode *video_path* into a list of at most *max_num_frames* PIL images.

    About one frame per second is kept: the stride is the average FPS rounded,
    clamped to at least 1. If that still yields more than *max_num_frames*
    indices, they are thinned uniformly across the clip.

    Args:
        video_path: Path to the video file.
        max_num_frames: Positive upper bound on the number of returned frames.

    Returns:
        A list of ``PIL.Image.Image`` frames in temporal order.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    # Clamp so a very low average frame rate cannot produce a zero range() step.
    stride = max(1, round(vr.get_avg_fps()))
    frame_idx = list(range(0, len(vr), stride))
    if len(frame_idx) > max_num_frames:
        # The original called an undefined ``uniform_sample`` here (NameError).
        frame_idx = _uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    return [Image.fromarray(frame.astype("uint8")) for frame in frames]
|
|
|
|
|
|
|
|
def web_search(query):
    """Ask Tavily's Q&A search for a direct answer to *query*.

    Returns whatever ``tavily_client.qna_search`` returns, unchanged.
    """
    return tavily_client.qna_search(query=query)
|
|
|
|
|
|
|
|
def _read_upload(upload):
    """Return ``(file_name, raw_bytes)`` for an uploaded file path or file-like object."""
    # Gradio components configured with type="filepath" pass a path string.
    # Objects with .read() are still accepted for direct callers.
    if hasattr(upload, "read"):
        return os.path.basename(str(getattr(upload, "name", "upload"))), upload.read()
    path = os.fspath(upload)
    with open(path, "rb") as fh:
        return os.path.basename(path), fh.read()


def _transcribe_audio(audio):
    """Transcribe an audio upload with Groq's hosted whisper-large-v3 model."""
    file_name, data = _read_upload(audio)
    transcription = client.audio.transcriptions.create(
        file=(file_name, data),
        model="whisper-large-v3",
    )
    return transcription.text


def _describe_image(image, user_prompt):
    """Answer *user_prompt* about the image at *image* with the MiniCPM-V-2 model."""
    with Image.open(image) as img:
        pil_image = img.convert("RGB")
    # MiniCPM-V-2's chat() (per its model card) takes the image as its own
    # argument, plain-text message content, and a required ``context`` argument.
    # It returns ``(answer, context, generation_config)``. The previous call used
    # the newer MiniCPM-V-2.6 message format and omitted ``context``.
    answer, _context, _generation_config = text_model.chat(
        image=pil_image,
        msgs=[{"role": "user", "content": user_prompt}],
        context=None,
        tokenizer=tokenizer,
    )
    return answer


def _generate_image(user_prompt):
    """Generate a single image for *user_prompt* with the SDXL-Lightning pipeline."""
    # The pipeline's UNet is the 4-step SDXL-Lightning checkpoint, which is
    # meant to run with 4 steps and no classifier-free guidance.
    result = image_pipe(prompt=user_prompt, num_inference_steps=4, guidance_scale=0)
    # The pipeline returns an output object; the generated image is in .images.
    return result.images[0]


def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
    """Route one user request to the matching backend and return its response.

    Routing order:
      1. If ``audio`` is given, its transcription replaces ``user_prompt``.
      2. If ``image`` is given, MiniCPM-V answers the prompt about the image.
      3. Otherwise, the first matching branch wins: a ``doc`` upload gets
         retrieval QA; a prompt containing "calculate" goes to ``numpy_calculate``;
         "generate" plus "image"/"picture" triggers image generation; the
         ``websearch`` flag triggers a Tavily search; anything else goes to the
         Groq chat model ``MODEL``.

    Args:
        user_prompt: Typed prompt text (may be empty or None when audio is used).
        image: Path to an uploaded image, or None.
        video: Accepted for interface compatibility; not used by any branch.
        audio: Path to (or file-like object for) an uploaded audio clip, or None.
        doc: Path to (or file-like object for) a UTF-8 text document, or None.
        websearch: Whether to answer via web search.

    Returns:
        A text answer for most branches, or a PIL image for image-generation
        requests.
    """
    if audio:
        user_prompt = _transcribe_audio(audio)
    # Guard against a missing prompt so the .lower() checks below cannot fail.
    user_prompt = user_prompt or ""

    if image:
        return _describe_image(image, user_prompt)

    # NOTE(review): ``video`` is never consumed; ``encode_video`` exists but is
    # not wired in. MiniCPM-V-2's chat() takes a single image.

    lowered = user_prompt.lower()
    if doc:
        file_name, raw = _read_upload(doc)
        return use_langchain_rag(file_name, raw.decode("utf-8"), user_prompt)
    if "calculate" in lowered:
        # SECURITY(review): numpy_calculate exec()s its argument, and this passes
        # raw user text straight through. Any user can run arbitrary Python on
        # the host by including "calculate" in the prompt. Replace with an
        # LLM-generated or sandboxed expression before exposing publicly.
        return numpy_calculate(user_prompt)
    if "generate" in lowered and ("image" in lowered or "picture" in lowered):
        return _generate_image(user_prompt)
    if websearch:
        return web_search(user_prompt)

    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_prompt},
        ],
        model=MODEL,
    )
    return chat_completion.choices[0].message.content
|
|
|
|
|
# ``spaces`` (Hugging Face ZeroGPU helper) was used as a decorator but never
# imported, so defining main_interface raised NameError at import time. Import it
# here, and fall back to a no-op decorator when the package is not installed.
# NOTE(review): on ZeroGPU hardware ``spaces`` is presumably required to be
# imported before any CUDA work (the model setup above already calls .to("cuda")).
# If so, move this import to the top of the module — TODO confirm.
try:
    import spaces

    _gpu = spaces.GPU()
except ImportError:
    def _gpu(fn):
        """Identity decorator used when the ``spaces`` package is unavailable."""
        return fn


@_gpu
def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, voice_only=False, websearch=False):
    """Gradio callback: move the models to CUDA, answer the request, optionally speak it.

    Args:
        user_prompt, image, video, audio, doc, websearch: Forwarded unchanged
            to ``handle_input``.
        voice_only: When true, the response is also synthesized to speech.

    Returns:
        ``(response, audio_path)``. ``audio_path`` is the WAV file written by
        ``play_voice_output`` in voice-only mode and ``None`` otherwise.
    """
    # Move every model to the GPU inside the decorated call. Presumably this is
    # because on ZeroGPU a GPU is only attached during such calls — confirm.
    text_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    unet.to("cuda", torch.float16)
    image_pipe.to("cuda")

    response = handle_input(user_prompt, image=image, video=video, audio=audio, doc=doc, websearch=websearch)

    if voice_only:
        return response, play_voice_output(response)
    return response, None
|
|
|
|
|
|
|
|
def create_ui():
    """Build the Gradio Blocks interface wired to ``main_interface``.

    Toggling "Enable Voice Only Mode" hides the text/image/video/document
    inputs and the web-search toggle. It also shows the audio input and the
    audio output. The Submit button stays visible in both modes.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Assistant")

        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                video_input = gr.Video(label="Upload a video", elem_id="video-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        output_label = gr.Label(label="Output")
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, video_input, audio_input, doc_input, voice_only_mode, websearch_mode],
            outputs=[output_label, audio_output],
        )

        # Components hidden / shown while voice-only mode is on. Submit is no
        # longer hidden: hiding it left no way to send the recorded audio.
        # audio_output is now shown so the synthesized reply can be heard; it was
        # created hidden and never revealed.
        hidden_in_voice_mode = [user_prompt, image_input, video_input, doc_input, websearch_mode]
        shown_in_voice_mode = [audio_input, audio_output]

        # Gradio requires one return value per output component. The previous
        # handler returned a single update for several outputs, which Gradio rejects.
        voice_only_mode.change(
            lambda enabled: [gr.update(visible=not enabled) for _ in hidden_in_voice_mode],
            inputs=voice_only_mode,
            outputs=hidden_in_voice_mode,
        )
        voice_only_mode.change(
            lambda enabled: [gr.update(visible=enabled) for _ in shown_in_voice_mode],
            inputs=voice_only_mode,
            outputs=shown_in_voice_mode,
        )

    return demo
|
|
|
|
|
|
|
|
# Build the app at module level so tools that import this file can find `demo`,
# but only start the blocking server when the file is run as a script.
demo = create_ui()

if __name__ == "__main__":
    demo.launch(inline=False)
|
|
|