from __future__ import annotations

import os
from typing import List, Dict, Any

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    pipeline,
    BlipForImageTextRetrieval,
    AutoProcessor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    ViltForQuestionAnswering,
    ViltProcessor,
    CLIPModel,
)
from transformers.utils import logging as hf_logging

# -------------------------------------------------------------------------
# Global config
# -------------------------------------------------------------------------
hf_logging.set_verbosity_error()
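
# DEVICE is the integer device index used by `pipeline(...)` (0 = first GPU, -1 = CPU);
# DEVICE_TORCH is the torch.device used to move bare models with `.to(...)`.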
DEVICE = 0 if torch.cuda.is_available() else -1
DEVICE_TORCH = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------------------------------------------------
# Helper functions (image rendering)
# -------------------------------------------------------------------------
def _ensure_rgb(img: Image.Image) -> Image.Image:
    if img.mode != "RGB":
        return img.convert("RGB")
    return img

def render_results_in_image(
    pil_image: Image.Image,
    detections: List[Dict[str, Any]],
    score_threshold: float = 0.5,
) -> Image.Image:
    img = _ensure_rgb(pil_image).copy()
    draw = ImageDraw.Draw(img)
    W, H = img.size
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
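    # Each detection from the object-detection pipeline is a dict of the form
    # {"score": float, "label": str, "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}.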
    for det in detections:
        score = float(det.get("score", 0.0))
        if score < score_threshold:
            continue
        lbl = str(det.get("label", ""))
        box = det.get("box", {})
        x1 = box.get("xmin", 0)
        y1 = box.get("ymin", 0)
        x2 = box.get("xmax", 0)
        y2 = box.get("ymax", 0)
        # clamp to bounds
        x1 = max(0, min(W, x1))
        x2 = max(0, min(W, x2))
        y1 = max(0, min(H, y1))
        y2 = max(0, min(H, y2))
        draw.rectangle([(x1, y1), (x2, y2)], outline=(0, 255, 0), width=3)
        text = f"{lbl} {score:.2f}"
        if font is not None:
            bbox = draw.textbbox((0, 0), text, font=font)
        else:
            bbox = draw.textbbox((0, 0), text)
        tw = bbox[2] - bbox[0]
        th = bbox[3] - bbox[1]
        pad = 2
        tx2 = min(x1 + tw + 2 * pad, W)
        ty2 = min(y1 + th + 2 * pad, H)
        draw.rectangle([(x1, y1), (tx2, ty2)], fill=(0, 255, 0))
        draw.text((x1 + pad, y1 + pad), text, fill=(0, 0, 0), font=font)
    return img
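
# Overlay segmentation masks as a translucent red layer. Each mask is expected to
# be a 2D array of 0/1 (or boolean) values, e.g. from the mask-generation pipeline.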
def show_masks_on_image(pil_image: Image.Image, masks: list[np.ndarray]) -> Image.Image:
    img = pil_image.convert("RGBA")
    overlay = Image.new("RGBA", img.size)
    for mask_np in masks:
        mask_uint8 = (mask_np * 255).astype("uint8")
        mask_img = Image.fromarray(mask_uint8).resize(img.size).convert("L")
        color = (255, 0, 0, 100)
        colored = Image.new("RGBA", img.size, color)
        overlay = Image.composite(colored, overlay, mask_img)
    combined = Image.alpha_composite(img, overlay)
    return combined

# -------------------------------------------------------------------------
# Load models / pipelines
# -------------------------------------------------------------------------
# Object Detection
od_pipe = pipeline(
    "object-detection",
    model="hustvl/yolos-tiny",
    device=DEVICE,
)

# Image Segmentation
segmentation_pipe = pipeline(
    "mask-generation",
    model="Zigeng/SlimSAM-uniform-50",
    device=DEVICE,
)
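# The mask-generation pipeline returns a dict whose "masks" entry is a list of
# per-object boolean arrays; sam_predict below overlays them on the input image.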

# Image Retrieval (BLIP ITM)
retrieval_model_name = "Salesforce/blip-itm-base-coco"
retrieval_model = BlipForImageTextRetrieval.from_pretrained(retrieval_model_name).to(DEVICE_TORCH)
retrieval_processor = AutoProcessor.from_pretrained(retrieval_model_name)

# Image Captioning
caption_model_name = "nlpconnect/vit-gpt2-image-captioning"
caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_name).to(DEVICE_TORCH)
caption_processor = ViTImageProcessor.from_pretrained(caption_model_name)
caption_tokenizer = AutoTokenizer.from_pretrained(caption_model_name)
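# The captioning checkpoint is a VisionEncoderDecoderModel (ViT encoder + GPT-2
# decoder), hence the separate image processor and tokenizer above.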

# Speech Recognition – EN & Multilingual (RU, etc.)
asr_en = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    device=DEVICE,
)
asr_multi = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-tiny",
    device=DEVICE,
)

# Text to Speech – EN & RU
tts_en = pipeline(
    task="text-to-speech",
    model="facebook/mms-tts-eng",
)
tts_ru = pipeline(
    task="text-to-speech",
    model="facebook/mms-tts-rus",
)
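# The MMS-TTS pipelines are created without a `device` argument, so they run on
# the CPU even when a GPU is available.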

# NLP pipelines
sentiment_pipe = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=DEVICE,
)
summarization_pipe = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=DEVICE,
)
translation_en_ru_pipe = pipeline(
    "translation_en_to_ru",
    model="Helsinki-NLP/opus-mt-en-ru",
    device=DEVICE,
)
translation_ru_en_pipe = pipeline(
    "translation_ru_to_en",
    model="Helsinki-NLP/opus-mt-ru-en",
    device=DEVICE,
)
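# The "translation_xx_to_yy" task names map onto the generic translation task;
# each OPUS-MT checkpoint covers a single language direction.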
qa_pipe = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
    device=DEVICE,
)

# -------------------------------------------------------------------------
# Tab 1: Object Detection
# -------------------------------------------------------------------------
def od_predict(pil_image: Image.Image, score_threshold: float) -> Image.Image:
    if pil_image is None:
        raise gr.Error("Please upload an image.")
    outputs = od_pipe(pil_image)
    rendered = render_results_in_image(pil_image, outputs, score_threshold=score_threshold)
    return rendered


def build_object_detection_tab() -> None:
    with gr.TabItem("1. Object Detection"):
        gr.Markdown(
            """
            ### 🧭 Object Detection
            Model: `hustvl/yolos-tiny`

            Upload an image and the model will detect the objects in it and draw bounding boxes.
            """
        )
        with gr.Row():
            inp = gr.Image(label="Input image", type="pil")
            out = gr.Image(label="Detections", type="pil")
        with gr.Row():
            thr = gr.Slider(
                0.0,
                1.0,
                value=0.5,
                step=0.01,
                label="Score threshold",
            )
        run = gr.Button("Detect")
        run.click(fn=od_predict, inputs=[inp, thr], outputs=out)

# -------------------------------------------------------------------------
# Tab 2: Image Segmentation
# -------------------------------------------------------------------------
def sam_predict(pil_image: Image.Image) -> Image.Image:
    if pil_image is None:
        raise gr.Error("Upload an image.")
    output = segmentation_pipe(pil_image, points_per_batch=32)
    masks = output.get("masks", [])
    # Some pipeline versions may return a single array instead of a list of masks.
    if isinstance(masks, np.ndarray):
        masks = [masks]
    img = show_masks_on_image(pil_image, masks)
    return img


def build_segmentation_tab() -> None:
    with gr.TabItem("2. Image Segmentation"):
        gr.Markdown(
            """
            ### 🎨 Image Segmentation
            Model: `Zigeng/SlimSAM-uniform-50`

            The image is segmented and semi-transparent masks are overlaid on top of it.
            """
        )
        with gr.Row():
            inp = gr.Image(label="Input image", type="pil")
            out = gr.Image(label="Segmentation overlay", type="pil")
        run = gr.Button("Segment")
        run.click(fn=sam_predict, inputs=inp, outputs=out)

# -------------------------------------------------------------------------
# Tab 3: Image Retrieval (Image–Text Matching)
# -------------------------------------------------------------------------
def check_match(pil_image: Image.Image, text: str) -> str:
    if pil_image is None or not text:
        raise gr.Error("Both an image and a text query are required.")
    inputs = retrieval_processor(images=pil_image, text=text, return_tensors="pt").to(DEVICE_TORCH)
    with torch.no_grad():
        outputs = retrieval_model(**inputs)
    # The BLIP ITM head returns 2-way logits of shape (batch, 2): (not matched, matched)
    logits = outputs.itm_score if hasattr(outputs, "itm_score") else outputs[0]
    probs = torch.nn.functional.softmax(logits, dim=1)
    prob = float(probs[0][1])  # probability of "matched"
    return f"Match probability: {prob:.4f}"


def build_retrieval_tab() -> None:
    with gr.TabItem("3. Image Retrieval"):
        gr.Markdown(
            """
            ### 🔎 Image–Text Matching (BLIP ITM)
            Model: `Salesforce/blip-itm-base-coco`

            Estimates how well a text description matches the image.
            """
        )
        with gr.Row():
            inp_image = gr.Image(type="pil", label="Image")
            inp_text = gr.Textbox(label="Text query", placeholder="Describe the image...")
        score_out = gr.Textbox(label="Match score")
        run = gr.Button("Check match")
        run.click(fn=check_match, inputs=[inp_image, inp_text], outputs=score_out)

# -------------------------------------------------------------------------
# Tab 4: Image Captioning
# -------------------------------------------------------------------------
def caption_predict(pil_image: Image.Image) -> str:
    if pil_image is None:
        raise gr.Error("Upload an image.")
    pixel_values = caption_processor(images=pil_image, return_tensors="pt").pixel_values.to(DEVICE_TORCH)
    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=50)
    caption = caption_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption


def build_captioning_tab() -> None:
    with gr.TabItem("4. Image Captioning"):
        gr.Markdown(
            """
            ### 📝 Image Captioning
            Model: `nlpconnect/vit-gpt2-image-captioning`

            Generates a text description of the image.
            """
        )
        with gr.Row():
            inp = gr.Image(type="pil", label="Input image")
            out = gr.Textbox(label="Generated caption")
        run = gr.Button("Generate caption")
        run.click(fn=caption_predict, inputs=inp, outputs=out)

# -------------------------------------------------------------------------
# Tab 5: Speech Recognition (ASR)
# -------------------------------------------------------------------------
def transcribe_audio(filepath: str, lang_model: str) -> str:
    if filepath is None:
        raise gr.Error("No audio found. Try again.")
    if lang_model == "English (whisper-tiny.en)":
        result = asr_en(filepath)
    else:
        result = asr_multi(filepath)
    text = result["text"] if isinstance(result, dict) else result
    return text


def build_speech_tab() -> None:
    with gr.TabItem("5. Speech Recognition"):
        gr.Markdown(
            """
            ### 🎙️ Speech Recognition
            - **English**: `openai/whisper-tiny.en`
            - **Multilingual (incl. Russian)**: `openai/whisper-tiny`

            You can record your voice with a microphone or upload an audio file.
            """
        )
        lang_selector = gr.Radio(
            choices=["English (whisper-tiny.en)", "Multilingual / Russian (whisper-tiny)"],
            value="English (whisper-tiny.en)",
            label="Model / language",
        )
        with gr.Tabs():
            with gr.Tab("Microphone"):
                mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record audio")
                mic_out = gr.Textbox(label="Transcription", lines=3)
                mic_btn = gr.Button("Transcribe")
                mic_btn.click(
                    fn=transcribe_audio,
                    inputs=[mic_input, lang_selector],
                    outputs=mic_out,
                )
            with gr.Tab("Upload file"):
                file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload audio")
                file_out = gr.Textbox(label="Transcription", lines=3)
                file_btn = gr.Button("Transcribe")
                file_btn.click(
                    fn=transcribe_audio,
                    inputs=[file_input, lang_selector],
                    outputs=file_out,
                )

# -------------------------------------------------------------------------
# Tab 6: Text to Speech (TTS)
# -------------------------------------------------------------------------
def tts_predict(text: str, lang_model: str):
    if not text or text.strip() == "":
        raise gr.Error("Please enter some text.")
    if lang_model == "English (mms-tts-eng)":
        result = tts_en(text)
    else:
        result = tts_ru(text)
    audio = result["audio"]
    sr = result["sampling_rate"]
    # MMS TTS returns a numpy array (sometimes with a leading batch dim) or a list of arrays.
    if isinstance(audio, list):
        audio = audio[0]
    audio = np.squeeze(np.asarray(audio))
    # gr.Audio(type="numpy") expects a (sampling_rate, waveform) tuple.
    return sr, audio


def build_tts_tab() -> None:
    with gr.TabItem("6. Text to Speech"):
        gr.Markdown(
            """
            ### 🔊 Text to Speech
            - **English**: `facebook/mms-tts-eng`
            - **Russian**: `facebook/mms-tts-rus`
            """
        )
        lang_selector = gr.Radio(
            choices=["English (mms-tts-eng)", "Russian (mms-tts-rus)"],
            value="Russian (mms-tts-rus)",
            label="Voice / language",
        )
        inp = gr.Textbox(
            label="Input text",
            lines=4,
            placeholder="Enter text in Russian or in English...",
        )
        out = gr.Audio(label="Generated speech", type="numpy")
        run = gr.Button("Speak")
        run.click(fn=tts_predict, inputs=[inp, lang_selector], outputs=out)

# -------------------------------------------------------------------------
# Tab 7: Natural Language Processing
# -------------------------------------------------------------------------
def nlp_run(task: str, text: str, context: str, question: str) -> str:
    text = text or ""
    context = context or ""
    question = question or ""
    if task == "Sentiment Analysis":
        if not text:
            raise gr.Error("Enter text for sentiment analysis.")
        result = sentiment_pipe(text)[0]
        return f"Label: {result['label']}, score: {result['score']:.4f}"
    if task == "Summarization":
        if not text:
            raise gr.Error("Enter text to summarize.")
        result = summarization_pipe(text, max_length=120, min_length=30, do_sample=False)[0]
        return result["summary_text"]
    if task == "Translation EN → RU":
        if not text:
            raise gr.Error("Enter English text for translation.")
        result = translation_en_ru_pipe(text)[0]
        return result["translation_text"]
    if task == "Translation RU → EN":
        if not text:
            raise gr.Error("Enter Russian text for translation.")
        result = translation_ru_en_pipe(text)[0]
        return result["translation_text"]
    if task == "Question Answering (QA)":
        if not context or not question:
            raise gr.Error("Provide both context and question for QA.")
        result = qa_pipe(question=question, context=context)
        return f"Answer: {result.get('answer', '')}\nScore: {result.get('score', 0.0):.4f}"
    raise gr.Error("Unknown task")


def build_nlp_tab() -> None:
    with gr.TabItem("7. Natural Language Processing"):
        gr.Markdown(
            """
            ### 📚 Natural Language Processing
            Available tasks:

            - **Sentiment Analysis** (EN)
            - **Summarization** (EN)
            - **Translation EN → RU / RU → EN**
            - **Question Answering (QA)**
            """
        )
        task_selector = gr.Radio(
            choices=[
                "Sentiment Analysis",
                "Summarization",
                "Translation EN → RU",
                "Translation RU → EN",
                "Question Answering (QA)",
            ],
            value="Sentiment Analysis",
            label="Task",
        )
        with gr.Row():
            text_input = gr.Textbox(
                label="Text (for sentiment / summarization / translation)",
                lines=6,
                placeholder="Type or paste text here...",
            )
        with gr.Accordion("Context & Question (for QA)", open=False):
            context_input = gr.Textbox(
                label="Context",
                lines=6,
                placeholder="Context passage for the question...",
            )
            question_input = gr.Textbox(
                label="Question",
                lines=2,
                placeholder="Your question about the context...",
            )
        out = gr.Textbox(label="Result", lines=8)
        run = gr.Button("Run")
        run.click(
            fn=nlp_run,
            inputs=[task_selector, text_input, context_input, question_input],
            outputs=out,
        )

# -------------------------------------------------------------------------
# Build full app
# -------------------------------------------------------------------------
CSS = """
#root .gradio-container {
    max-width: 1100px;
    margin: auto;
}
"""


def build_app() -> gr.Blocks:
    with gr.Blocks(css=CSS, title="Multimodal Playground – Vision, Audio & NLP") as demo:
        gr.Markdown(
            """
            # 🎛️ Multimodal Playground — Vision, Audio & NLP
            A demo of several open-source models:

            * Object detection and image segmentation
            * Image–Text Retrieval and Captioning
            * Speech Recognition (RU/EN) and Text-to-Speech (RU/EN)
            * Classic NLP tasks
            """
        )
        with gr.Tabs():
            build_object_detection_tab()
            build_segmentation_tab()
            build_retrieval_tab()
            build_captioning_tab()
            build_speech_tab()
            build_tts_tab()
            build_nlp_tab()
    return demo

app = build_app()

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))