# app.py
# Author: arjunvankani — commit 0236db7 ("Create app.py"), 3.81 kB.
import gradio as gr
from transformers import pipeline
from diffusers import StableDiffusionPipeline
import torch
# --- Load NLP pipelines ---
# Each pipeline is created once at import time and shared by the handlers below.
clf = pipeline(task="text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")  # SST-2 classifier
ner = pipeline(
    task="ner",
    model="dslim/bert-base-NER",
    aggregation_strategy="simple",  # merge sub-word tokens into whole entities
)
mlm = pipeline(task="fill-mask", model="bert-base-uncased")  # expects a [MASK] token in the input
qa = pipeline(task="question-answering", model="distilbert-base-cased-distilled-squad")  # extractive QA
# --- Vision pipelines ---
img_clf = pipeline("image-classification", model="google/vit-base-patch16-224")
det = pipeline("object-detection", model="facebook/detr-resnet-50")
# Mask2Former COCO checkpoints are published with a task suffix
# (-panoptic / -instance); the bare "...-coco" name does not resolve.
seg = pipeline("image-segmentation", model="facebook/mask2former-swin-large-coco-panoptic")
# --- Diffusion model for text-to-image ---
# Half precision is only used on CUDA: many CPU kernels have no float16
# implementation, so loading fp16 weights and running on CPU fails.
_sd_device = "cuda" if torch.cuda.is_available() else "cpu"
_sd_dtype = torch.float16 if _sd_device == "cuda" else torch.float32
# The original "runwayml/stable-diffusion-v1-5" repo was removed from the Hub;
# this is the official mirror of the same weights.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=_sd_dtype
)
sd_pipe = sd_pipe.to(_sd_device)
# --- Speech ---
asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")
# "espnet/kan-bayashi_ljspeech_vits" is an ESPnet-format repo with no
# transformers config, so `pipeline` cannot load it. Use a VITS checkpoint
# published in transformers format instead.
tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
# --- Functions ---
def classify_text(text):
    """Classify *text* with the `clf` pipeline (DistilBERT fine-tuned on SST-2).

    Returns the pipeline's raw output unchanged.
    """
    result = clf(text)
    return result
def ner_text(text):
    """Extract named entities from *text* using the `ner` pipeline.

    The pipeline was built with aggregation_strategy="simple", so its raw
    output is returned as-is.
    """
    entities = ner(text)
    return entities
def fill_blank(text):
    """Predict candidates for the masked token(s) in *text* via the `mlm` pipeline.

    *text* is expected to contain the model's mask token ("[MASK]" for
    bert-base-uncased).
    """
    candidates = mlm(text)
    return candidates
def answer_question(context, question):
    """Answer *question* from the passage *context* with the `qa` pipeline.

    Parameter order matches the Gradio inputs ([context, question]); the
    pipeline itself is called by keyword.
    """
    qa_kwargs = {"question": question, "context": context}
    return qa(**qa_kwargs)
def classify_image(image):
    """Classify *image* (a PIL image from the Gradio input) with the ViT `img_clf` pipeline."""
    labels = img_clf(image)
    return labels
def detect_objects(image):
    """Detect objects in *image* with the DETR `det` pipeline and return its raw output."""
    detections = det(image)
    return detections
def segment_image(image):
    """Segment *image* with the `seg` pipeline and return a JSON-friendly summary.

    The image-segmentation pipeline yields one dict per segment. Its "mask"
    entry is a PIL image, which is not JSON data, so it cannot be shown in
    the gr.JSON output. Only the label and score of each segment are returned.

    Returns:
        list[dict]: ``{"label": str | None, "score": float | None}`` per segment.
    """
    summary = []
    for segment in seg(image):
        score = segment.get("score")
        summary.append(
            {
                "label": segment.get("label"),
                # Coerce numpy/torch scalars to a plain float for serialization.
                "score": None if score is None else float(score),
            }
        )
    return summary
def generate_image(prompt):
    """Render *prompt* with the Stable Diffusion pipeline and return the first generated image."""
    return sd_pipe(prompt).images[0]
def transcribe(audio):
    """Transcribe the audio file at path *audio* with the Whisper `asr` pipeline.

    Args:
        audio: File path supplied by gr.Audio(type="filepath"), or None.

    Returns:
        str: The recognized text, or "" when no audio is present.
    """
    # gr.Audio's `.change` event also fires when the input is cleared and then
    # passes None; the ASR pipeline raises on None, so return an empty result.
    if audio is None:
        return ""
    return asr(audio)["text"]
def speak_text(text):
    """Synthesize speech for *text* with the `tts` pipeline.

    Returns:
        tuple | None: ``(sample_rate, waveform)`` in the form gr.Audio accepts,
        or None when *text* is empty or whitespace-only (nothing to speak).

    Raises:
        KeyError: if the pipeline output carries neither a "sampling_rate"
            nor a "sample_rate" entry.
    """
    import numpy as np

    if not text or not text.strip():
        return None
    output = tts(text)
    # transformers' text-to-audio pipeline reports the rate as "sampling_rate";
    # "sample_rate" is kept as a fallback for other output shapes.
    if "sampling_rate" in output:
        rate = output["sampling_rate"]
    else:
        rate = output["sample_rate"]
    # Drop a leading batch/channel axis (e.g. shape (1, n)) so Gradio gets a
    # 1-D mono waveform.
    waveform = np.asarray(output["audio"]).squeeze()
    return (rate, waveform)
# --- Gradio Interface ---
# One tab per task; each tab wires its input component(s) to the matching
# handler above via an event listener.
with gr.Blocks() as demo:
    gr.Markdown("# 🌍 Environmental AI Toolkit")

    # Text tasks: run on Enter in the textbox.
    with gr.Tab("Sentence Classification"):
        txt_in = gr.Textbox(label="Enter text")
        txt_out = gr.JSON(label="Classification Result")
        txt_in.submit(fn=classify_text, inputs=txt_in, outputs=txt_out)

    with gr.Tab("NER"):
        ner_in = gr.Textbox(label="Enter text")
        ner_out = gr.JSON(label="Entities")
        ner_in.submit(fn=ner_text, inputs=ner_in, outputs=ner_out)

    with gr.Tab("Fill-in-the-Blank"):
        mlm_in = gr.Textbox(label="Enter sentence with [MASK]")
        mlm_out = gr.JSON(label="Predictions")
        mlm_in.submit(fn=fill_blank, inputs=mlm_in, outputs=mlm_out)

    with gr.Tab("Question Answering"):
        context = gr.Textbox(label="Context")
        question = gr.Textbox(label="Question")
        qa_out = gr.JSON(label="Answer")
        qa_button = gr.Button("Answer")
        # Input order must match answer_question(context, question).
        qa_button.click(fn=answer_question, inputs=[context, question], outputs=qa_out)

    # Vision tasks: run as soon as an image is uploaded.
    with gr.Tab("Image Classification"):
        img_in = gr.Image(type="pil")
        img_out = gr.JSON(label="Labels")
        img_in.upload(fn=classify_image, inputs=img_in, outputs=img_out)

    with gr.Tab("Object Detection"):
        det_in = gr.Image(type="pil")
        det_out = gr.JSON(label="Objects")
        det_in.upload(fn=detect_objects, inputs=det_in, outputs=det_out)

    with gr.Tab("Segmentation"):
        seg_in = gr.Image(type="pil")
        seg_out = gr.JSON(label="Segments")
        seg_in.upload(fn=segment_image, inputs=seg_in, outputs=seg_out)

    with gr.Tab("Image Generation"):
        gen_in = gr.Textbox(label="Prompt")
        gen_out = gr.Image(label="Generated Image")
        gen_button = gr.Button("Generate")
        gen_button.click(fn=generate_image, inputs=gen_in, outputs=gen_out)

    # Speech tasks.
    with gr.Tab("Speech Recognition"):
        audio_in = gr.Audio(type="filepath")
        audio_out = gr.Textbox(label="Transcription")
        audio_in.change(fn=transcribe, inputs=audio_in, outputs=audio_out)

    with gr.Tab("Text-to-Speech"):
        tts_in = gr.Textbox(label="Text to Speak")
        tts_out = gr.Audio(label="Generated Speech")
        tts_button = gr.Button("Speak")
        tts_button.click(fn=speak_text, inputs=tts_in, outputs=tts_out)
# Start the web server only when this file is executed as a script; importing
# the module (e.g. from tests or tooling) should not block on launch().
if __name__ == "__main__":
    demo.launch()