Spaces:
Runtime error
Runtime error
| """Module which defines the code for the "Manage models" tab.""" | |
| from collections.abc import Sequence | |
| from functools import partial | |
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
# Function to search for RVC models on Hugging Face
def search_rvc_models(query):
    """
    Search the Hugging Face Hub for models of the ``rvc`` library.

    Parameters
    ----------
    query : str
        Free-text search query forwarded to the Hub model search API.

    Returns
    -------
    pd.DataFrame
        One row per matching model with the columns ``id``, ``likes``,
        ``downloads`` and ``downloadUrl``, sorted by ``downloads`` in
        descending order. If the request fails, the Hub answers with a
        non-200 status, or no model matches, a single-row table whose
        ``id`` column reads "No models found" is returned instead.
    """
    not_found = pd.DataFrame({"id": ["No models found"]})
    try:
        # Let requests URL-encode the query rather than interpolating it
        # into the URL, so characters such as "&", "#" or spaces are safe.
        # A timeout keeps the UI handler from hanging indefinitely.
        response = requests.get(
            "https://huggingface.co/api/models",
            params={"search": query, "library": "rvc"},
            timeout=10,
        )
    except requests.RequestException:
        return not_found
    if response.status_code != 200:
        return not_found
    models = response.json()
    if not models:
        # pd.DataFrame([]) has no columns, so selecting columns from it
        # would raise a KeyError; report "no results" instead.
        return not_found
    # Keep only the desired columns (reindex tolerates a missing field).
    df = pd.DataFrame(models).reindex(columns=["id", "likes", "downloads"])
    # ``assign`` returns a new frame, avoiding chained-assignment warnings
    # when adding a column to a column subset.
    df = df.assign(downloadUrl="https://huggingface.co/" + df["id"])
    return df.sort_values(by="downloads", ascending=False)
| from ultimate_rvc.core.manage.models import ( | |
| delete_all_models, | |
| delete_models, | |
| download_model, | |
| filter_public_models_table, | |
| get_public_model_tags, | |
| get_saved_model_names, | |
| upload_model, | |
| ) | |
| from ultimate_rvc.web.common import ( | |
| PROGRESS_BAR, | |
| confirm_box_js, | |
| confirmation_harness, | |
| exception_harness, | |
| render_msg, | |
| update_dropdowns, | |
| ) | |
| from ultimate_rvc.web.typing_extra import DropdownValue | |
def _update_models(
    num_components: int,
    value: DropdownValue = None,
    value_indices: Sequence[int] | None = None,
) -> gr.Dropdown | tuple[gr.Dropdown, ...]:
    """
    Update the choices of one or more dropdown components to the set of
    currently saved voice models.

    Optionally updates the default value of one or more of these
    components.

    Parameters
    ----------
    num_components : int
        Number of dropdown components to update.
    value : DropdownValue, optional
        New value for dropdown components.
    value_indices : Sequence[int], optional
        Indices of dropdown components to update the value for. When
        omitted, an empty list is used.

    Returns
    -------
    gr.Dropdown | tuple[gr.Dropdown, ...]
        Updated dropdown component or components.

    """
    # Build a fresh list per call instead of sharing one mutable default
    # object across every invocation.
    if value_indices is None:
        value_indices = []
    return update_dropdowns(get_saved_model_names, num_components, value, value_indices)
def _filter_public_models_table(tags: Sequence[str], query: str) -> gr.Dataframe:
    """
    Narrow down the metadata table of public voice models.

    Both the selected tags and the free-text search query are forwarded
    to ``filter_public_models_table``, and its result is wrapped in a
    Gradio dataframe component.

    Parameters
    ----------
    tags : Sequence[str]
        Tags used to filter the metadata table.
    query : str
        Search text used to filter the metadata table.

    Returns
    -------
    gr.Dataframe
        Gradio dataframe component holding the filtered table.

    """
    return gr.Dataframe(value=filter_public_models_table(tags, query))
def _autofill_model_name_and_url(
    public_models_table: pd.DataFrame,
    select_event: gr.SelectData,
) -> tuple[gr.Textbox, gr.Textbox]:
    """
    Autofill two textboxes with respectively the name and URL that is
    saved in the currently selected row of the public models table.

    Parameters
    ----------
    public_models_table : pd.DataFrame
        The public models table saved in a Pandas dataframe.
    select_event : gr.SelectData
        Event containing the index of the currently selected row in the
        public models table.

    Returns
    -------
    name : gr.Textbox
        The textbox containing the model name.
    url : gr.Textbox
        The textbox containing the model URL.

    Raises
    ------
    TypeError
        If the index in the provided event is not a sequence, or if the
        selected row does not hold string values in its "Name" and "URL"
        columns.

    """
    event_index = select_event.index
    if not isinstance(event_index, Sequence):
        err_msg = (
            f"Expected a sequence of indices but got {type(event_index)} from the"
            " provided event."
        )
        raise TypeError(err_msg)
    # The first element of the event index is the selected row's
    # position in the displayed table. Look it up positionally (iloc)
    # rather than by index label (loc), so the lookup stays correct even
    # if the dataframe's index is not a plain 0..n-1 range.
    selected_row = public_models_table.iloc[event_index[0]]
    url = selected_row["URL"]
    name = selected_row["Name"]
    if isinstance(url, str) and isinstance(name, str):
        return gr.Textbox(value=name), gr.Textbox(value=url)
    err_msg = (
        "Expected model name and URL to be strings but got"
        f" {type(name)} and {type(url)} respectively."
    )
    raise TypeError(err_msg)
def render(
    model_delete: gr.Dropdown,
    model_1click: gr.Dropdown,
    model_multi: gr.Dropdown,
) -> None:
    """
    Render "Manage models" tab.

    Builds four sub-tabs ("Download model", "Upload model",
    "Delete models" and "Search models") and wires their event
    handlers. After any download, upload or deletion succeeds, the
    three voice-model dropdowns are refreshed with the currently saved
    model names.

    Parameters
    ----------
    model_delete : gr.Dropdown
        Dropdown for selecting voice models to delete in the
        "Delete models" tab.
    model_1click : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "One-click generation" tab.
    model_multi : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "Multi-step generation" tab.

    """
    # Download tab
    # Hidden checkbox passed as the first input of the confirmation-guarded
    # delete handlers below; presumably it receives the result of the JS
    # confirm box (see confirm_box_js) -- NOTE(review): confirm.
    dummy_checkbox = gr.Checkbox(visible=False)
    with gr.Tab("Download model"):
        with gr.Accordion("View public models table", open=False):
            gr.Markdown("")
            gr.Markdown("*HOW TO USE*")
            gr.Markdown(
                "- Filter voice models by selecting one or more tags and/or providing a"
                " search query.",
            )
            gr.Markdown(
                "- Select a row in the table to autofill the name and"
                " URL for the given voice model in the form fields below.",
            )
            gr.Markdown("")
            with gr.Row():
                search_query = gr.Textbox(label="Search query")
                tags = gr.CheckboxGroup(
                    value=[],
                    label="Tags",
                    choices=get_public_model_tags(),
                )
            with gr.Row():
                # The table's value is recomputed from the tags and search
                # query inputs via _filter_public_models_table.
                public_models_table = gr.Dataframe(
                    value=_filter_public_models_table,
                    inputs=[tags, search_query],
                    headers=["Name", "Description", "Tags", "Credit", "Added", "URL"],
                    label="Public models table",
                    interactive=False,
                )
        with gr.Row():
            model_url = gr.Textbox(
                label="Model URL",
                info=(
                    "Should point to a zip file containing a .pth model file and"
                    " optionally also an .index file."
                ),
            )
            model_name = gr.Textbox(
                label="Model name",
                info="Enter a unique name for the voice model.",
            )
        with gr.Row(equal_height=True):
            download_btn = gr.Button("Download 🌐", variant="primary", scale=19)
            download_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        # Selecting a table row fills in the name and URL fields.
        public_models_table.select(
            _autofill_model_name_and_url,
            inputs=public_models_table,
            outputs=[model_name, model_url],
            show_progress="hidden",
        )
        download_btn_click = download_btn.click(
            partial(
                exception_harness(download_model),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=[model_url, model_name],
            outputs=download_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully downloaded voice model!",
            ),
            inputs=model_name,
            outputs=download_msg,
            show_progress="hidden",
        )
    # Upload tab
    with gr.Tab("Upload model"):
        with gr.Accordion("HOW TO USE"):
            gr.Markdown("")
            gr.Markdown(
                "1. Find the .pth file for a locally trained RVC model (e.g. in your"
                " local weights folder) and optionally also a corresponding .index file"
                " (e.g. in your logs/[name] folder)",
            )
            gr.Markdown(
                "2. Upload the files directly or save them to a folder, then compress"
                " that folder and upload the resulting .zip file",
            )
            gr.Markdown("3. Enter a unique name for the uploaded model")
            gr.Markdown("4. Click 'Upload'")
        with gr.Row():
            model_files = gr.File(
                label="Files",
                file_count="multiple",
                file_types=[".zip", ".pth", ".index"],
            )
            local_model_name = gr.Textbox(label="Model name")
        with gr.Row(equal_height=True):
            upload_btn = gr.Button("Upload", variant="primary", scale=19)
            upload_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        upload_btn_click = upload_btn.click(
            partial(exception_harness(upload_model), progress_bar=PROGRESS_BAR),
            inputs=[model_files, local_model_name],
            outputs=upload_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully uploaded voice model!",
            ),
            inputs=local_model_name,
            outputs=upload_msg,
            show_progress="hidden",
        )
    # Delete tab
    with gr.Tab("Delete models"):
        with gr.Row():
            with gr.Column():
                model_delete.render()
                delete_btn = gr.Button("Delete selected", variant="secondary")
                delete_all_btn = gr.Button("Delete all", variant="primary")
            with gr.Column():
                delete_msg = gr.Textbox(label="Output message", interactive=False)
        delete_btn_click = delete_btn.click(
            partial(confirmation_harness(delete_models), progress_bar=PROGRESS_BAR),
            inputs=[dummy_checkbox, model_delete],
            outputs=delete_msg,
            js=confirm_box_js(
                "Are you sure you want to delete the selected voice models?",
            ),
        ).success(
            partial(render_msg, "[-] Successfully deleted selected voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )
        delete_all_btn_click = delete_all_btn.click(
            partial(
                confirmation_harness(delete_all_models),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=dummy_checkbox,
            outputs=delete_msg,
            js=confirm_box_js("Are you sure you want to delete all voice models?"),
        ).success(
            partial(render_msg, "[-] Successfully deleted all voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )
    # Search tab: query the Hugging Face Hub for RVC models.
    with gr.Tab("Search models"):
        query = gr.Textbox(label="Search for RVC models", placeholder="Enter your search query here")
        search_button = gr.Button("Search")
        results = gr.Dataframe(label="Search Results")
        search_button.click(fn=search_rvc_models, inputs=query, outputs=results)
    # After any action that changes the set of saved models succeeds,
    # refresh all three dropdowns; only the third output (model_delete,
    # index 2) also has its value reset to [].
    for click_event in [
        download_btn_click,
        upload_btn_click,
        delete_btn_click,
        delete_all_btn_click,
    ]:
        click_event.success(
            partial(_update_models, 3, [], [2]),
            outputs=[model_1click, model_multi, model_delete],
            show_progress="hidden",
        )