Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import gc | |
| import json | |
| from pipeline_tabs.text_tab import text_tab | |
| from pipeline_tabs.diffusion_tab import diffusion_tab | |
# Registry of loaded models. The same dict object is handed to each tab
# builder together with unload_all_models, so it must only ever be mutated
# in place, never rebound.
model_cache = dict()


def unload_all_models():
    """Drop every cached model and release the memory it held.

    Empties ``model_cache`` in place, so every holder of the dict sees the
    change. Then runs the garbage collector so the dropped objects are
    actually freed. When CUDA is present, it also releases PyTorch's cached,
    unoccupied GPU memory.
    """
    model_cache.clear()
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def pretty_json():
    """Return the keys of ``model_cache`` as an indented JSON array string."""
    return json.dumps(list(model_cache), indent=2, ensure_ascii=False)


with gr.Blocks() as demo:
    with gr.Tabs():
        text_tab(model_cache, unload_all_models)
        diffusion_tab(model_cache, unload_all_models)

    # Shared state display. The tabs mutate model_cache directly and emit no
    # event wired to this box. It therefore reflects the cache only as of the
    # last page load or the last click on the refresh button below, not live.
    state_box = gr.Textbox(
        label="Loaded Models",
        lines=4,
        interactive=False,
        value=pretty_json(),
    )
    # Re-read the cache on every page load. The value above was computed once,
    # when the UI was built.
    demo.load(fn=pretty_json, inputs=None, outputs=state_box)

    # Manual refresh, needed because loads/unloads in the tabs don't push here.
    refresh_btn = gr.Button("Refresh Model State")
    refresh_btn.click(fn=pretty_json, inputs=None, outputs=state_box)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()