import gradio as gr
import torch
from transformers import pipeline
import gc
import json

# Define available models/tasks
MODEL_CONFIGS = [
    {
        "name": "Text Generation (GPT-2)",
        "task": "text-generation",
        "model": "gpt2",
        "input_type": "text",
        "output_type": "text",
    },
    {
        "name": "Image Classification (ViT)",
        "task": "image-classification",
        "model": "google/vit-base-patch16-224",
        "input_type": "image",
        "output_type": "label",
    },
    # Add more models/tasks as needed; see the sketch below
]
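
# A third entry could look like this (a sketch, not part of the original app;
# any text-classification checkpoint on the Hub works, this is just a common one):
# {
#     "name": "Sentiment Analysis (DistilBERT)",
#     "task": "text-classification",
#     "model": "distilbert-base-uncased-finetuned-sst-2-english",
#     "input_type": "text",
#     "output_type": "label",
# },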

# Cache of loaded pipelines, keyed by "task|model", for lazy loading.
# (The shared gr.State is created inside the Blocks context below; creating
# it at module level is what broke the original app.)
model_cache = {}

def load_model(task, model_name):
    # Run on the GPU when one is available, otherwise fall back to CPU
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=model_name, device=device)
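
# Quick sanity check (sketch; max_new_tokens is an ordinary generation kwarg):
#   pipe = load_model("text-generation", "gpt2")
#   pipe("Hello world", max_new_tokens=20)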

def unload_model(model_key):
    # Drop the cached pipeline and release its memory
    if model_key in model_cache:
        del model_cache[model_key]
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
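
# Optional convenience (a sketch, not wired into the UI below): fetch-or-load
# in one call, so callers never have to touch model_cache directly.
def get_model(task, model_name):
    key = f"{task}|{model_name}"
    if key not in model_cache:
        model_cache[key] = load_model(task, model_name)
    return model_cache[key]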

with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model, Multi-Task Gradio Demo\n_Switch between models and tasks in one Space!_")
    # State components must be created inside the Blocks context
    shared_state = gr.State({"active_model": None, "last_result": None})
    with gr.Tabs():
        for config in MODEL_CONFIGS:
            with gr.Tab(config["name"]):
                status = gr.Markdown(f"**Model:** {config['model']}<br>**Task:** {config['task']}")
                load_btn = gr.Button("Load Model")
                unload_btn = gr.Button("Unload Model")
                if config["input_type"] == "text":
                    input_comp = gr.Textbox(label="Input Text")
                elif config["input_type"] == "image":
                    # type="pil" hands the transformers pipeline a PIL image it can consume directly
                    input_comp = gr.Image(label="Input Image", type="pil")
                else:
                    input_comp = gr.Textbox(label="Input")
                run_btn = gr.Button("Run Model")
                output_comp = gr.Textbox(label="Output", lines=4)
                model_key = f"{config['task']}|{config['model']}"

                # Bind loop variables as default arguments: Python closures
                # capture names by reference, so without this every tab would
                # end up using the last config in MODEL_CONFIGS.
                def do_load(state, model_key=model_key, config=config):
                    if model_key not in model_cache:
                        model_cache[model_key] = load_model(config["task"], config["model"])
                    state = dict(state)
                    state["active_model"] = model_key
                    return f"Loaded: {model_key}", state

                def do_unload(state, model_key=model_key):
                    unload_model(model_key)
                    state = dict(state)
                    state["active_model"] = None
                    return f"Unloaded: {model_key}", state

                def do_run(inp, state, model_key=model_key):
                    if model_key not in model_cache:
                        return "Model not loaded!", state
                    pipe = model_cache[model_key]
                    result = pipe(inp)
                    state = dict(state)
                    state["last_result"] = result
                    return str(result), state

                load_btn.click(do_load, shared_state, [status, shared_state])
                unload_btn.click(do_unload, shared_state, [status, shared_state])
                run_btn.click(do_run, [input_comp, shared_state], [output_comp, shared_state])

    # Shared state display
    def pretty_json(state):
        # default=str keeps json.dumps from choking on non-serializable results
        return json.dumps(state, indent=2, ensure_ascii=False, default=str)

    shared_state_box = gr.Textbox(label="Shared State", lines=8, interactive=False)
    # Note: gr.State only supports .change in recent Gradio releases; on older
    # versions, add shared_state_box to each button's outputs instead.
    shared_state.change(pretty_json, shared_state, shared_state_box)

demo.launch()
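
# A Spaces "Runtime error" banner is often just a missing dependency. A minimal
# requirements.txt for this app might look like the following (the version pins
# are an assumption; Pillow backs the type="pil" image input):
#
#   gradio>=4.29
#   torch
#   transformers
#   Pillow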