from __future__ import annotations
import os
from typing import List, Dict, Any
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import (
pipeline,
BlipForImageTextRetrieval,
AutoProcessor,
VisionEncoderDecoderModel,
ViTImageProcessor,
AutoTokenizer,
ViltForQuestionAnswering,
ViltProcessor,
CLIPModel,
)
from transformers.utils import logging as hf_logging
# -------------------------------------------------------------------------
# Global config
# -------------------------------------------------------------------------
hf_logging.set_verbosity_error()
DEVICE = 0 if torch.cuda.is_available() else -1
DEVICE_TORCH = torch.device("cuda" if torch.cuda.is_available() else "cpu")
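# `pipeline(...)` takes an integer device index (-1 = CPU, 0 = first GPU),
# while bare `from_pretrained` models are moved with a torch.device, so both
# forms are derived once here and reused below.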
# -------------------------------------------------------------------------
# Helper functions (image rendering)
# -------------------------------------------------------------------------
def _ensure_rgb(img: Image.Image) -> Image.Image:
if img.mode != "RGB":
return img.convert("RGB")
return img
def render_results_in_image(
pil_image: Image.Image,
detections: List[Dict[str, Any]],
score_threshold: float = 0.5,
) -> Image.Image:
img = _ensure_rgb(pil_image).copy()
draw = ImageDraw.Draw(img)
W, H = img.size
try:
font = ImageFont.load_default()
except Exception:
font = None
for det in detections:
score = float(det.get("score", 0.0))
if score < score_threshold:
continue
lbl = str(det.get("label", ""))
box = det.get("box", {})
x1 = box.get("xmin", 0)
y1 = box.get("ymin", 0)
x2 = box.get("xmax", 0)
y2 = box.get("ymax", 0)
# clamp to bounds
x1 = max(0, min(W, x1))
x2 = max(0, min(W, x2))
y1 = max(0, min(H, y1))
y2 = max(0, min(H, y2))
draw.rectangle([(x1, y1), (x2, y2)], outline=(0, 255, 0), width=3)
text = f"{lbl} {score:.2f}"
if font is not None:
bbox = draw.textbbox((0, 0), text, font=font)
else:
bbox = draw.textbbox((0, 0), text)
tw = bbox[2] - bbox[0]
th = bbox[3] - bbox[1]
pad = 2
tx2 = min(x1 + tw + 2 * pad, W)
ty2 = min(y1 + th + 2 * pad, H)
draw.rectangle([(x1, y1), (tx2, ty2)], fill=(0, 255, 0))
draw.text((x1 + pad, y1 + pad), text, fill=(0, 0, 0), font=font)
return img
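# render_results_in_image() consumes the standard transformers
# object-detection output schema, e.g. (illustrative values):
#   {"score": 0.97, "label": "cat",
#    "box": {"xmin": 10, "ymin": 20, "xmax": 200, "ymax": 180}}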
def show_masks_on_image(pil_image: Image.Image, masks: List[np.ndarray]) -> Image.Image:
img = pil_image.convert("RGBA")
overlay = Image.new("RGBA", img.size)
for mask_np in masks:
mask_uint8 = (mask_np * 255).astype("uint8")
mask_img = Image.fromarray(mask_uint8).resize(img.size).convert("L")
color = (255, 0, 0, 100)
colored = Image.new("RGBA", img.size, color)
overlay = Image.composite(colored, overlay, mask_img)
combined = Image.alpha_composite(img, overlay)
return combined
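# show_masks_on_image() assumes each mask is a 2-D array with values in
# {0, 1} (bool or float), as the mask-generation pipeline produces; each is
# scaled to a 0-255 grayscale alpha mask before compositing.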
# -------------------------------------------------------------------------
# Load models / pipelines
# -------------------------------------------------------------------------
# Object Detection
od_pipe = pipeline(
"object-detection",
model="hustvl/yolos-tiny",
device=DEVICE,
)
# Image Segmentation
segmentation_pipe = pipeline(
"mask-generation",
model="Zigeng/SlimSAM-uniform-50",
device=DEVICE,
)
# Image Retrieval (BLIP ITM)
retrieval_model_name = "Salesforce/blip-itm-base-coco"
retrieval_model = BlipForImageTextRetrieval.from_pretrained(retrieval_model_name).to(DEVICE_TORCH)
retrieval_processor = AutoProcessor.from_pretrained(retrieval_model_name)
# Image Captioning
caption_model_name = "nlpconnect/vit-gpt2-image-captioning"
caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_name).to(DEVICE_TORCH)
caption_processor = ViTImageProcessor.from_pretrained(caption_model_name)
caption_tokenizer = AutoTokenizer.from_pretrained(caption_model_name)
# Speech Recognition – EN & Multilingual (RU, etc.)
asr_en = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
device=DEVICE,
)
asr_multi = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
device=DEVICE,
)
# Text to Speech – EN & RU
tts_en = pipeline(
task="text-to-speech",
model="facebook/mms-tts-eng",
)
tts_ru = pipeline(
task="text-to-speech",
model="facebook/mms-tts-rus",
)
# NLP pipelines
sentiment_pipe = pipeline(
"sentiment-analysis",
model="distilbert-base-uncased-finetuned-sst-2-english",
device=DEVICE,
)
summarization_pipe = pipeline(
"summarization",
model="facebook/bart-large-cnn",
device=DEVICE,
)
translation_en_ru_pipe = pipeline(
"translation_en_to_ru",
model="Helsinki-NLP/opus-mt-en-ru",
device=DEVICE,
)
translation_ru_en_pipe = pipeline(
"translation_ru_to_en",
model="Helsinki-NLP/opus-mt-ru-en",
device=DEVICE,
)
qa_pipe = pipeline(
"question-answering",
model="deepset/roberta-base-squad2",
device=DEVICE,
)
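# All models above are loaded eagerly at startup: first requests are fast,
# but everything stays resident in memory at once. A minimal lazy-loading
# sketch (hypothetical alternative, not wired into the app):
#
#   from functools import lru_cache
#
#   @lru_cache(maxsize=None)
#   def get_pipe(task: str, model: str):
#       return pipeline(task, model=model, device=DEVICE)
#
#   # usage: get_pipe("summarization", "facebook/bart-large-cnn")(text)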
# -------------------------------------------------------------------------
# Tab 1: Object Detection
# -------------------------------------------------------------------------
def od_predict(pil_image: Image.Image, score_threshold: float) -> Image.Image:
if pil_image is None:
raise gr.Error("Please upload an image.")
outputs = od_pipe(pil_image)
rendered = render_results_in_image(pil_image, outputs, score_threshold=score_threshold)
return rendered
def build_object_detection_tab() -> None:
with gr.TabItem("1. Object Detection"):
gr.Markdown(
"""
### 🧭 Object Detection
            Model: `hustvl/yolos-tiny`
            Upload an image and the model will detect objects and draw bounding boxes.
"""
)
with gr.Row():
inp = gr.Image(label="Input image", type="pil")
out = gr.Image(label="Detections", type="pil")
with gr.Row():
thr = gr.Slider(
0.0,
1.0,
value=0.5,
step=0.01,
label="Score threshold",
)
run = gr.Button("Detect")
run.click(fn=od_predict, inputs=[inp, thr], outputs=out)
# -------------------------------------------------------------------------
# Tab 2: Image Segmentation
# -------------------------------------------------------------------------
def sam_predict(pil_image: Image.Image) -> Image.Image:
if pil_image is None:
raise gr.Error("Upload an image.")
output = segmentation_pipe(pil_image, points_per_batch=32)
    masks = output.get("masks", [])
    # some pipelines may return a single ndarray instead of a list of masks
    if isinstance(masks, np.ndarray):
        masks = [masks]
img = show_masks_on_image(pil_image, masks)
return img
def build_segmentation_tab() -> None:
with gr.TabItem("2. Image Segmentation"):
gr.Markdown(
"""
### 🎨 Image Segmentation
            Model: `Zigeng/SlimSAM-uniform-50`
            Image segmentation: semi-transparent masks are overlaid on top of the image.
"""
)
with gr.Row():
inp = gr.Image(label="Input image", type="pil")
out = gr.Image(label="Segmentation overlay", type="pil")
run = gr.Button("Segment")
run.click(fn=sam_predict, inputs=inp, outputs=out)
# -------------------------------------------------------------------------
# Tab 3: Image Retrieval (Image–Text Matching)
# -------------------------------------------------------------------------
def check_match(pil_image: Image.Image, text: str) -> str:
if pil_image is None or not text:
raise gr.Error("Нужна картинка и текст.")
inputs = retrieval_processor(images=pil_image, text=text, return_tensors="pt").to(DEVICE_TORCH)
with torch.no_grad():
outputs = retrieval_model(**inputs)
    # BlipForImageTextRetrieval exposes the ITM logits as `itm_score`
    # with shape (batch, 2); fall back to the first output otherwise.
    logits = getattr(outputs, "itm_score", outputs[0])
    probs = torch.nn.functional.softmax(logits, dim=1)
    prob = float(probs[0][1])  # probability of the "matched" class
return f"Match probability: {prob:.4f}"
def build_retrieval_tab() -> None:
with gr.TabItem("3. Image Retrieval"):
gr.Markdown(
"""
### 🔎 Image–Text Matching (BLIP ITM)
            Model: `Salesforce/blip-itm-base-coco`
            Scores how well a text description matches the image.
"""
)
with gr.Row():
inp_image = gr.Image(type="pil", label="Image")
inp_text = gr.Textbox(label="Text query", placeholder="Describe the image...")
score_out = gr.Textbox(label="Match score")
run = gr.Button("Check match")
run.click(fn=check_match, inputs=[inp_image, inp_text], outputs=score_out)
# -------------------------------------------------------------------------
# Tab 4: Image Captioning
# -------------------------------------------------------------------------
def caption_predict(pil_image: Image.Image) -> str:
if pil_image is None:
raise gr.Error("Upload an image.")
pixel_values = caption_processor(images=pil_image, return_tensors="pt").pixel_values.to(DEVICE_TORCH)
with torch.no_grad():
output_ids = caption_model.generate(pixel_values, max_length=50)
caption = caption_tokenizer.decode(output_ids[0], skip_special_tokens=True)
return caption
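# caption_predict() decodes greedily; beam search usually yields better
# captions at the cost of latency, e.g.:
#   output_ids = caption_model.generate(pixel_values, max_length=50, num_beams=4)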
def build_captioning_tab() -> None:
with gr.TabItem("4. Image Captioning"):
gr.Markdown(
"""
### 📝 Image Captioning
            Model: `nlpconnect/vit-gpt2-image-captioning`
            Generates a text description of an image.
"""
)
with gr.Row():
inp = gr.Image(type="pil", label="Input image")
out = gr.Textbox(label="Generated caption")
run = gr.Button("Generate caption")
run.click(fn=caption_predict, inputs=inp, outputs=out)
# -------------------------------------------------------------------------
# Tab 5: Speech Recognition (ASR)
# -------------------------------------------------------------------------
def transcribe_audio(filepath: str, lang_model: str) -> str:
if filepath is None:
raise gr.Error("No audio found. Try again.")
if lang_model == "English (whisper-tiny.en)":
result = asr_en(filepath)
else:
result = asr_multi(filepath)
text = result["text"] if isinstance(result, dict) else result
return text
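# Whisper works on 30-second windows; for longer recordings the ASR pipeline
# can chunk the input, e.g.:
#   asr_multi(filepath, chunk_length_s=30, return_timestamps=True)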
def build_speech_tab() -> None:
with gr.TabItem("5. Speech Recognition"):
gr.Markdown(
"""
### 🎙️ Speech Recognition
- **English**: `openai/whisper-tiny.en`
- **Multilingual (incl. Russian)**: `openai/whisper-tiny`
            You can record audio with the microphone or upload an audio file.
"""
)
lang_selector = gr.Radio(
choices=["English (whisper-tiny.en)", "Multilingual / Russian (whisper-tiny)"],
value="English (whisper-tiny.en)",
label="Model / language",
)
with gr.Tabs():
with gr.Tab("Microphone"):
mic_input = gr.Audio(sources="microphone", type="filepath", label="Record audio")
mic_out = gr.Textbox(label="Transcription", lines=3)
mic_btn = gr.Button("Transcribe")
mic_btn.click(
fn=transcribe_audio,
inputs=[mic_input, lang_selector],
outputs=mic_out,
)
with gr.Tab("Upload file"):
file_input = gr.Audio(sources="upload", type="filepath", label="Upload audio")
file_out = gr.Textbox(label="Transcription", lines=3)
file_btn = gr.Button("Transcribe")
file_btn.click(
fn=transcribe_audio,
inputs=[file_input, lang_selector],
outputs=file_out,
)
# -------------------------------------------------------------------------
# Tab 6: Text to Speech (TTS)
# -------------------------------------------------------------------------
def tts_predict(text: str, lang_model: str):
    if not text or text.strip() == "":
        raise gr.Error("Enter some text.")
    if lang_model == "English (mms-tts-eng)":
        result = tts_en(text)
    else:
        result = tts_ru(text)
    audio = result["audio"]
    sr = result["sampling_rate"]
    # MMS TTS returns a numpy array or a list of arrays
    if isinstance(audio, list):
        audio = audio[0]
    # flatten a possible (1, n_samples) waveform to mono for gr.Audio
    audio = np.asarray(audio).squeeze()
    return sr, audio
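# gr.Audio(type="numpy") expects a (sampling_rate, waveform) tuple where the
# waveform is 1-D mono or 2-D (samples, channels); MMS TTS may return a
# (1, n_samples) array, hence the squeeze above.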
def build_tts_tab() -> None:
with gr.TabItem("6. Text to Speech"):
gr.Markdown(
"""
### 🔊 Text to Speech
- **English**: `facebook/mms-tts-eng`
- **Russian**: `facebook/mms-tts-rus`
"""
)
lang_selector = gr.Radio(
choices=["English (mms-tts-eng)", "Russian (mms-tts-rus)"],
value="Russian (mms-tts-rus)",
label="Voice / language",
)
inp = gr.Textbox(
label="Input text",
lines=4,
placeholder="Введите текст (для русского) или введите текст на английском...",
)
out = gr.Audio(label="Generated speech", type="numpy")
run = gr.Button("Speak")
run.click(fn=tts_predict, inputs=[inp, lang_selector], outputs=out)
# -------------------------------------------------------------------------
# Tab 7: Natural Language Processing
# -------------------------------------------------------------------------
def nlp_run(task: str, text: str, context: str, question: str) -> str:
text = text or ""
context = context or ""
question = question or ""
if task == "Sentiment Analysis":
if not text:
raise gr.Error("Enter text for sentiment analysis.")
result = sentiment_pipe(text)[0]
return f"Label: {result['label']}, score: {result['score']:.4f}"
if task == "Summarization":
if not text:
raise gr.Error("Enter text to summarize.")
result = summarization_pipe(text, max_length=120, min_length=30, do_sample=False)[0]
return result["summary_text"]
if task == "Translation EN → RU":
if not text:
raise gr.Error("Enter English text for translation.")
result = translation_en_ru_pipe(text)[0]
return result["translation_text"]
if task == "Translation RU → EN":
if not text:
raise gr.Error("Введите русский текст для перевода.")
result = translation_ru_en_pipe(text)[0]
return result["translation_text"]
if task == "Question Answering (QA)":
if not context or not question:
raise gr.Error("Provide both context and question for QA.")
result = qa_pipe(question=question, context=context)
return f"Answer: {result.get('answer', '')}\nScore: {result.get('score', 0.0):.4f}"
raise gr.Error("Unknown task")
def build_nlp_tab() -> None:
with gr.TabItem("7. Natural Language Processing"):
gr.Markdown(
"""
### 📚 Natural Language Processing
            Available tasks:
- **Sentiment Analysis** (EN)
- **Summarization** (EN)
- **Translation EN → RU / RU → EN**
- **Question Answering (QA)**
"""
)
task_selector = gr.Radio(
choices=[
"Sentiment Analysis",
"Summarization",
"Translation EN → RU",
"Translation RU → EN",
"Question Answering (QA)",
],
value="Sentiment Analysis",
label="Task",
)
with gr.Row():
text_input = gr.Textbox(
label="Text (for sentiment / summarization / translation)",
lines=6,
placeholder="Введите или вставьте текст здесь...",
)
with gr.Accordion("Context & Question (for QA)", open=False):
context_input = gr.Textbox(
label="Context",
lines=6,
placeholder="Текст-контекст для вопроса...",
)
question_input = gr.Textbox(
label="Question",
lines=2,
placeholder="Ваш вопрос по контексту...",
)
out = gr.Textbox(label="Result", lines=8)
run = gr.Button("Run")
run.click(
fn=nlp_run,
inputs=[task_selector, text_input, context_input, question_input],
outputs=out,
)
# -------------------------------------------------------------------------
# Build full app
# -------------------------------------------------------------------------
CSS = """
#root .gradio-container {
max-width: 1100px;
margin: auto;
}
"""
def build_app() -> gr.Blocks:
with gr.Blocks(css=CSS, title="Multimodal Playground – Vision, Audio & NLP") as demo:
gr.Markdown(
"""
# 🎛️ Multimodal Playground — Vision, Audio & NLP
            A demo of several open-source models:
            * Object detection and image segmentation
            * Image–text retrieval and captioning
            * Speech recognition (RU/EN) and text-to-speech (RU/EN)
            * Classic NLP tasks
"""
)
with gr.Tabs():
build_object_detection_tab()
build_segmentation_tab()
build_retrieval_tab()
build_captioning_tab()
build_speech_tab()
build_tts_tab()
build_nlp_tab()
return demo
app = build_app()
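# On shared hardware a request queue serializes GPU-bound work; enabling it
# is one line before launch (a sketch, default settings may suffice):
#   app.queue(max_size=16)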
if __name__ == "__main__":
app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))