| | |
| | import gradio as gr |
| | import torch |
| | import torchaudio |
| | from transformers import ( |
| | pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, |
| | AutoImageProcessor, AutoModelForObjectDetection, |
| | BlipForQuestionAnswering, BlipProcessor, CLIPModel, CLIPProcessor, |
| | VitsModel, AutoTokenizer |
| | ) |
| | from PIL import Image, ImageDraw |
| | import requests |
| | import numpy as np |
| | import soundfile as sf |
| | from gtts import gTTS |
| | import tempfile |
| | import os |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | |
# Module-level cache of loaded models/pipelines, keyed by short model name
# (plus "<name>_processor" entries for CLIP). Avoids reloading the heavy
# checkpoints on every Gradio request.
models = {}
| |
|
def load_audio_model(model_name):
    """Lazily load and cache an audio pipeline by short name.

    Supported names:
        "whisper"            - multilingual ASR (openai/whisper-small)
        "wav2vec2"           - Russian ASR (bond005/wav2vec2-large-ru-golos)
        "audio_classifier"   - general AudioSet classification
        "emotion_classifier" - speech emotion recognition

    Returns:
        The cached transformers pipeline for *model_name*.

    Raises:
        ValueError: if model_name is not one of the supported names.
    """
    # (task, checkpoint) table replaces the original if/elif chain.
    specs = {
        "whisper": ("automatic-speech-recognition", "openai/whisper-small"),
        "wav2vec2": ("automatic-speech-recognition", "bond005/wav2vec2-large-ru-golos"),
        "audio_classifier": ("audio-classification", "MIT/ast-finetuned-audioset-10-10-0.4593"),
        "emotion_classifier": ("audio-classification", "superb/hubert-large-superb-er"),
    }
    if model_name not in specs:
        # The original fell through to a bare KeyError on the cache lookup;
        # fail early with a clear message instead.
        raise ValueError(f"Unknown audio model: {model_name}")
    if model_name not in models:
        task, checkpoint = specs[model_name]
        models[model_name] = pipeline(task, model=checkpoint)
    return models[model_name]
| |
|
def load_image_model(model_name):
    """Lazily load and cache an image model/pipeline by short name.

    Supported names: "object_detection", "segmentation", "captioning",
    "vqa" (all transformers pipelines) and "clip" (raw CLIPModel; its
    processor is cached under "clip_processor" in the `models` dict).

    Returns:
        The cached pipeline, or the CLIPModel for "clip".

    Raises:
        ValueError: if model_name is not one of the supported names.
    """
    pipeline_specs = {
        "object_detection": ("object-detection", "facebook/detr-resnet-50"),
        "segmentation": ("image-segmentation", "nvidia/segformer-b0-finetuned-ade-512-512"),
        "captioning": ("image-to-text", "Salesforce/blip-image-captioning-base"),
        "vqa": ("visual-question-answering", "dandelin/vilt-b32-finetuned-vqa"),
    }
    if model_name != "clip" and model_name not in pipeline_specs:
        # The original fell through to a bare KeyError on the cache lookup;
        # fail early with a clear message instead.
        raise ValueError(f"Unknown image model: {model_name}")
    if model_name not in models:
        if model_name == "clip":
            # CLIP needs its processor alongside the model for encoding.
            models[model_name] = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            models[f"{model_name}_processor"] = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            task, checkpoint = pipeline_specs[model_name]
            models[model_name] = pipeline(task, model=checkpoint)
    return models[model_name]
| |
|
| | |
def audio_classification(audio_file, model_type):
    """Classify an audio file and format the top-5 predictions as text.

    Args:
        audio_file: path to the audio file (Gradio filepath input).
        model_type: key understood by load_audio_model
            ("audio_classifier" or "emotion_classifier").

    Returns:
        A newline-separated ranking of up to five label/score pairs.
    """
    classifier = load_audio_model(model_type)
    predictions = classifier(audio_file)

    lines = ["Топ-5 предсказаний:"]
    for rank, prediction in enumerate(predictions[:5], start=1):
        lines.append(f"{rank}. {prediction['label']}: {prediction['score']:.4f}")

    return "\n".join(lines) + "\n"
| |
|
def speech_recognition(audio_file, model_type):
    """Transcribe speech from an audio file.

    Args:
        audio_file: path to the audio file (Gradio filepath input).
        model_type: "whisper" (multilingual; forced to Russian here)
            or "wav2vec2" (Russian-specific model).

    Returns:
        The transcribed text.
    """
    recognizer = load_audio_model(model_type)

    # Whisper needs the target language pinned; wav2vec2 takes no
    # generation kwargs.
    extra_kwargs = {}
    if model_type == "whisper":
        extra_kwargs["generate_kwargs"] = {"language": "russian"}

    transcription = recognizer(audio_file, **extra_kwargs)
    return transcription["text"]
| |
|
def text_to_speech(text, model_type):
    """Synthesize Russian speech for *text*; return a path to an audio file.

    Args:
        text: the text to synthesize (expected to be Russian).
        model_type:
            "silero" - Silero TTS via torch.hub (48 kHz WAV)
            "gtts"   - Google Translate TTS (MP3)
            "mms"    - Facebook MMS VITS model (WAV)

    Returns:
        Path to a temporary audio file (caller/Gradio is responsible for it).

    Raises:
        ValueError: for an unsupported model_type.
    """
    if model_type == "silero":
        # torch.hub.load for silero_tts returns (model, example_text).
        model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                  model='silero_tts',
                                  language='ru',
                                  speaker='ru_v3')

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            model.save_wav(text=text, speaker='aidar', sample_rate=48000, audio_path=f.name)
            return f.name

    elif model_type == "gtts":
        # BUG FIX: gTTS produces MP3 data, so the temp file must carry an
        # ".mp3" suffix — the original wrote MP3 bytes into a ".wav" file,
        # which breaks decoders that key off the extension.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            gTTS(text=text, lang='ru').save(f.name)
            return f.name

    elif model_type == "mms":
        # Cache the MMS model/tokenizer in the module-level `models` dict
        # instead of re-instantiating them on every request (the original
        # reloaded both each call).
        if "mms_tts" not in models:
            models["mms_tts"] = VitsModel.from_pretrained("facebook/mms-tts-rus")
            models["mms_tts_tokenizer"] = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")
        model = models["mms_tts"]
        tokenizer = models["mms_tts_tokenizer"]

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            waveform = model(**inputs).waveform

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, waveform.numpy().squeeze(), model.config.sampling_rate)
            return f.name

    # The original silently returned None for an unknown type; fail loudly.
    raise ValueError(f"Unknown TTS model: {model_type}")
| |
|
| | |
def object_detection(image):
    """Run DETR object detection and return an annotated copy of *image*.

    Args:
        image: PIL.Image input from Gradio.

    Returns:
        A new PIL.Image with red bounding boxes and "label: score" text.
    """
    detector = load_image_model("object_detection")
    detections = detector(image)

    # BUG FIX: draw on a copy — the original mutated the caller's image in
    # place, so re-running detection stacked boxes onto the same picture.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for detection in detections:
        box = detection['box']
        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                       outline='red', width=3)
        draw.text((box['xmin'], box['ymin']),
                  f"{detection['label']}: {detection['score']:.2f}", fill='red')

    return annotated
| |
|
def image_segmentation(image):
    """Segment *image* and return the mask of the first detected segment.

    Args:
        image: PIL.Image input from Gradio.

    Returns:
        The first segment's mask image, as produced by the
        image-segmentation pipeline.
    """
    segmenter = load_image_model("segmentation")
    segments = segmenter(image)

    first_segment = segments[0]
    return first_segment['mask']
| |
|
def image_captioning(image):
    """Generate a natural-language caption for *image* using BLIP.

    Args:
        image: PIL.Image input from Gradio.

    Returns:
        The generated caption string.
    """
    captioner = load_image_model("captioning")
    captions = captioner(image)
    top_caption = captions[0]
    return top_caption['generated_text']
| |
|
def visual_question_answering(image, question):
    """Answer a free-text question about *image* with the ViLT VQA model.

    Args:
        image: PIL.Image input from Gradio.
        question: the question text.

    Returns:
        The top answer formatted as "answer (confidence: 0.xxx)".
    """
    answerer = load_image_model("vqa")
    answers = answerer(image, question)
    best = answers[0]
    return f"{best['answer']} (confidence: {best['score']:.3f})"
| |
|
def zero_shot_classification(image, classes):
    """CLIP zero-shot classification of *image* over comma-separated *classes*.

    Args:
        image: PIL.Image input from Gradio.
        classes: comma-separated class names, e.g. "кот, собака, машина".

    Returns:
        A formatted per-class probability listing, or a prompt message when
        no usable class names were supplied.
    """
    # Robustness: drop empty entries — "a,,b" or a trailing comma used to
    # produce empty-string classes that skewed the softmax distribution.
    class_list = [cls.strip() for cls in classes.split(",") if cls.strip()]
    if not class_list:
        return "Пожалуйста, укажите хотя бы один класс"

    model = load_image_model("clip")
    processor = models["clip_processor"]

    inputs = processor(text=class_list, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class axis turns image-text logits into probabilities.
        probs = outputs.logits_per_image.softmax(dim=1)

    lines = ["Zero-Shot Classification Results:"]
    for i, cls in enumerate(class_list):
        lines.append(f"{cls}: {probs[0][i].item():.4f}")

    return "\n".join(lines) + "\n"
| |
|
def image_retrieval(images, query):
    """Find the gallery image best matching *query* using CLIP embeddings.

    Args:
        images: list of images from gr.Gallery (items may be plain PIL
            images or (image, caption) tuples depending on Gradio version).
        query: free-text search query.

    Returns:
        (message, best_image) tuple; best_image is None when input is missing.
    """
    if not images or not query:
        # BUG FIX: this handler is wired to two outputs (text + image), so
        # the early exit must also return two values — the original returned
        # a lone string, which breaks Gradio's output mapping.
        return "Пожалуйста, загрузите изображения и введите запрос", None

    # gr.Gallery may hand items over as (image, caption) tuples; unwrap so
    # the CLIP processor always receives plain images.
    images = [item[0] if isinstance(item, tuple) else item for item in images]

    model = load_image_model("clip")
    processor = models["clip_processor"]

    # L2-normalized image embeddings.
    image_inputs = processor(images=images, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_embeddings = model.get_image_features(**image_inputs)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

    # L2-normalized embedding of the single text query.
    text_inputs = processor(text=[query], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embeddings = model.get_text_features(**text_inputs)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    # Cosine similarity of every image against the query; squeeze the
    # singleton text axis so indexing yields scalars.
    similarities = (image_embeddings @ text_embeddings.T).squeeze(1)

    best_idx = similarities.argmax().item()
    best_score = similarities[best_idx].item()

    return f"Лучшее изображение: #{best_idx + 1} (схожесть: {best_score:.4f})", images[best_idx]
| |
|
| | |
# ---------------------------------------------------------------------------
# Gradio UI: one tab per task. Each tab wires an input column (plus a model
# selector where several checkpoints are available) and an output column to
# the corresponding handler function defined above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Multimodal AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎯 Мультимодальные AI модели")
    gr.Markdown("Демонстрация различных задач компьютерного зрения и обработки звука с использованием Hugging Face Transformers")

    # --- Audio classification (general sounds / speech emotion) -----------
    with gr.Tab("🎵 Классификация аудио"):
        gr.Markdown("## Zero-Shot Audio Classification")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Загрузите аудиофайл", type="filepath")
                audio_model_dropdown = gr.Dropdown(
                    choices=["audio_classifier", "emotion_classifier"],
                    label="Выберите модель",
                    value="audio_classifier",
                    info="audio_classifier - общая классификация, emotion_classifier - эмоции в речи"
                )
                classify_btn = gr.Button("Классифицировать")
            with gr.Column():
                audio_output = gr.Textbox(label="Результаты классификации", lines=10)

        classify_btn.click(
            fn=audio_classification,
            inputs=[audio_input, audio_model_dropdown],
            outputs=audio_output
        )

    # --- Speech recognition (Whisper / wav2vec2) ---------------------------
    with gr.Tab("🗣️ Распознавание речи"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        with gr.Row():
            with gr.Column():
                asr_audio_input = gr.Audio(label="Загрузите аудио с речью", type="filepath")
                asr_model_dropdown = gr.Dropdown(
                    choices=["whisper", "wav2vec2"],
                    label="Выберите модель",
                    value="whisper",
                    info="whisper - многоязычная, wav2vec2 - специализированная для русского"
                )
                transcribe_btn = gr.Button("Транскрибировать")
            with gr.Column():
                asr_output = gr.Textbox(label="Транскрипция", lines=5)

        transcribe_btn.click(
            fn=speech_recognition,
            inputs=[asr_audio_input, asr_model_dropdown],
            outputs=asr_output
        )

    # --- Text-to-speech (Silero / gTTS / MMS) ------------------------------
    with gr.Tab("🔊 Синтез речи"):
        gr.Markdown("## Text-to-Speech (TTS)")
        with gr.Row():
            with gr.Column():
                tts_text_input = gr.Textbox(
                    label="Введите текст для синтеза",
                    placeholder="Введите текст на русском языке...",
                    lines=3
                )
                tts_model_dropdown = gr.Dropdown(
                    choices=["silero", "gtts", "mms"],
                    label="Выберите модель",
                    value="silero",
                    info="silero - высокое качество, gtts - Google TTS, mms - Facebook MMS"
                )
                synthesize_btn = gr.Button("Синтезировать речь")
            with gr.Column():
                tts_output = gr.Audio(label="Синтезированная речь")

        synthesize_btn.click(
            fn=text_to_speech,
            inputs=[tts_text_input, tts_model_dropdown],
            outputs=tts_output
        )

    # --- Object detection (DETR) -------------------------------------------
    with gr.Tab("📦 Детекция объектов"):
        gr.Markdown("## Object Detection")
        with gr.Row():
            with gr.Column():
                obj_detection_input = gr.Image(label="Загрузите изображение", type="pil")
                detect_btn = gr.Button("Обнаружить объекты")
            with gr.Column():
                obj_detection_output = gr.Image(label="Результат детекции")

        detect_btn.click(
            fn=object_detection,
            inputs=obj_detection_input,
            outputs=obj_detection_output
        )

    # --- Semantic segmentation (SegFormer) ---------------------------------
    with gr.Tab("🎨 Сегментация"):
        gr.Markdown("## Image Segmentation")
        with gr.Row():
            with gr.Column():
                seg_input = gr.Image(label="Загрузите изображение", type="pil")
                segment_btn = gr.Button("Сегментировать")
            with gr.Column():
                seg_output = gr.Image(label="Маска сегментации")

        segment_btn.click(
            fn=image_segmentation,
            inputs=seg_input,
            outputs=seg_output
        )

    # --- Image captioning (BLIP) -------------------------------------------
    with gr.Tab("📝 Описание изображений"):
        gr.Markdown("## Image Captioning")
        with gr.Row():
            with gr.Column():
                caption_input = gr.Image(label="Загрузите изображение", type="pil")
                caption_btn = gr.Button("Сгенерировать описание")
            with gr.Column():
                caption_output = gr.Textbox(label="Описание изображения", lines=3)

        caption_btn.click(
            fn=image_captioning,
            inputs=caption_input,
            outputs=caption_output
        )

    # --- Visual question answering (ViLT) ----------------------------------
    with gr.Tab("❓ Визуальные вопросы"):
        gr.Markdown("## Visual Question Answering")
        with gr.Row():
            with gr.Column():
                vqa_image_input = gr.Image(label="Загрузите изображение", type="pil")
                vqa_question_input = gr.Textbox(
                    label="Вопрос об изображении",
                    placeholder="Что происходит на этом изображении?",
                    lines=2
                )
                vqa_btn = gr.Button("Ответить на вопрос")
            with gr.Column():
                vqa_output = gr.Textbox(label="Ответ", lines=3)

        vqa_btn.click(
            fn=visual_question_answering,
            inputs=[vqa_image_input, vqa_question_input],
            outputs=vqa_output
        )

    # --- Zero-shot image classification (CLIP) -----------------------------
    with gr.Tab("🎯 Zero-Shot классификация"):
        gr.Markdown("## Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                zs_image_input = gr.Image(label="Загрузите изображение", type="pil")
                zs_classes_input = gr.Textbox(
                    label="Классы для классификации (через запятую)",
                    placeholder="человек, машина, дерево, здание, животное",
                    lines=2
                )
                zs_classify_btn = gr.Button("Классифицировать")
            with gr.Column():
                zs_output = gr.Textbox(label="Результаты классификации", lines=10)

        zs_classify_btn.click(
            fn=zero_shot_classification,
            inputs=[zs_image_input, zs_classes_input],
            outputs=zs_output
        )

    # --- Text-to-image retrieval (CLIP) ------------------------------------
    with gr.Tab("🔍 Поиск изображений"):
        gr.Markdown("## Image Retrieval")
        with gr.Row():
            with gr.Column():
                retrieval_images_input = gr.Gallery(
                    label="Загрузите изображения для поиска",
                    type="pil"
                )
                retrieval_query_input = gr.Textbox(
                    label="Текстовый запрос",
                    placeholder="описание того, что вы ищете...",
                    lines=2
                )
                retrieval_btn = gr.Button("Найти изображение")
            with gr.Column():
                retrieval_output_text = gr.Textbox(label="Результат поиска")
                retrieval_output_image = gr.Image(label="Найденное изображение")

        # NOTE(review): this handler has TWO outputs; image_retrieval must
        # always return a (text, image) pair, including on its early-exit path.
        retrieval_btn.click(
            fn=image_retrieval,
            inputs=[retrieval_images_input, retrieval_query_input],
            outputs=[retrieval_output_text, retrieval_output_image]
        )

    gr.Markdown("---")
    gr.Markdown("### 📊 Поддерживаемые задачи:")
    gr.Markdown("""
    - **🎵 Аудио**: Классификация, распознавание речи, синтез речи
    - **👁️ Компьютерное зрение**: Детекция объектов, сегментация, описание изображений
    - **🤖 Мультимодальные**: Визуальные вопросы, zero-shot классификация, поиск по изображениям
    """)
| |
|
# Launch the Gradio app when run as a script. share=True asks Gradio to open
# a temporary public tunnel URL in addition to the local server.
if __name__ == "__main__":
    demo.launch(share=True)