Spaces:
Sleeping
Sleeping
| """ | |
| Gradio interface for converting models. | |
| """ | |
| import os | |
| import uuid | |
| import re | |
| import subprocess | |
| import gradio as gr | |
| from demo import constants, utils | |
| from lczerolens import backends | |
# Matches a "<digits>x<digits>" tag such as "128x10"; the first number is read
# as the filter count and the second as the block count.
_ARCH_PATTERN = re.compile(r"(?P<n_filters>\d+)x(?P<n_blocks>\d+)")


def _parse_architecture(filename):
    """Return ``(n_filters, n_blocks)`` parsed from *filename*, or ``(-1, -1)`` if absent."""
    match = _ARCH_PATTERN.search(filename)
    if match is None:
        return -1, -1
    return int(match.group("n_filters")), int(match.group("n_blocks"))


def _scan_directory(directory, suffix, model_type):
    """Return ``[filename, model_type, n_blocks, n_filters]`` rows for files in *directory* ending with *suffix*."""
    rows = []
    for filename in os.listdir(directory):
        if filename.endswith(suffix):
            n_filters, n_blocks = _parse_architecture(filename)
            rows.append([filename, model_type, n_blocks, n_filters])
    return rows


def get_models_info(onnx=True, leela=True):
    """
    Get the names and architecture of the models in the model directories.

    Args:
        onnx: Include ``.onnx`` files from ``constants.ONNX_MODEL_DIRECTORY``.
        leela: Include ``.pb.gz`` files from ``constants.LEELA_MODEL_DIRECTORY``.

    Returns:
        A list of ``[filename, type, n_blocks, n_filters]`` rows, where ``type``
        is ``"ONNX"`` or ``"LEELA"``. ``n_blocks``/``n_filters`` are ``-1`` when
        the file name carries no ``<filters>x<blocks>`` tag. ONNX rows come first;
        within a directory the order is whatever ``os.listdir`` returns.
    """
    model_df = []
    if onnx:
        model_df.extend(_scan_directory(constants.ONNX_MODEL_DIRECTORY, ".onnx", "ONNX"))
    if leela:
        model_df.extend(_scan_directory(constants.LEELA_MODEL_DIRECTORY, ".pb.gz", "LEELA"))
    return model_df
def save_model(tmp_file_path):
    """
    Validate an uploaded Leela network and move it into the Leela model directory.

    The ``file`` utility must identify the upload as gzip data whose header
    records an original file name (``was "<name>"``). That name, reduced to its
    final path component, plus ``.gz`` becomes the stored file name. The network
    is checked with ``backends.describenet`` before it replaces any existing file.

    Args:
        tmp_file_path: Path of the uploaded bytes. It is renamed away once the
            gzip checks pass. If an earlier check fails, it is left in place for
            the caller to delete.

    Raises:
        RuntimeError: If ``file`` exits non-zero, its output cannot be parsed,
            the data is not gzip, the recorded name is unusable, or
            ``backends.describenet`` raises ``RuntimeError``. Other exceptions
            from ``describenet`` propagate unchanged after cleanup.
    """
    # `-b` omits the "<path>: " prefix, so the description can be parsed
    # directly. `run` drains the pipes; `Popen.wait()` with PIPEs can deadlock.
    result = subprocess.run(
        ["file", "-b", "--", tmp_file_path],
        capture_output=True,
        check=False,
    )
    if result.returncode != 0:
        raise RuntimeError(f"`file` failed on {tmp_file_path}")
    file_desc = result.stdout.decode("utf-8", errors="replace").strip()
    type_match = re.match(r"(?P<type>[a-zA-Z]+)", file_desc)
    rename_match = re.search(r"was\s\"(?P<name>.+)\"", file_desc)
    if rename_match is None or type_match is None:
        raise RuntimeError(f"Could not identify uploaded file: {file_desc!r}")
    if type_match.group("type") != "gzip":
        raise RuntimeError("Uploaded file is not gzip-compressed.")

    # The recorded name comes from the (untrusted) gzip header. Keep only its
    # last path component so it cannot point outside the model directory.
    model_name = os.path.basename(rename_match.group("name"))
    if model_name in ("", ".", ".."):
        raise RuntimeError("Uploaded file records no usable name.")

    final_path = f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}.gz"
    # Validate under a private name that ends like the final one. A bad upload
    # therefore never overwrites (and then deletes) an existing model that
    # shares its name.
    staging_path = f"{tmp_file_path}-{model_name}.gz"
    os.rename(tmp_file_path, staging_path)
    try:
        backends.describenet(staging_path)
    except BaseException:
        os.remove(staging_path)
        raise
    os.replace(staging_path, final_path)
def list_models():
    """
    Return every model file name as a one-column row, in ascending name order.
    """
    names = sorted(info[0] for info in get_models_info())
    return [[name] for name in names]
def on_select_model_df(
    evt: gr.SelectData,
):
    """
    Forward the value of the clicked dataframe cell to the selected-model box.
    """
    # Gradio injects `evt` based on the SelectData annotation.
    selected_value = evt.value
    return selected_value
def convert_model(
    model_name: str,
):
    """
    Convert a Leela ``.pb.gz`` network to ONNX.

    The source is read from ``constants.LEELA_MODEL_DIRECTORY``. The result is
    written to ``constants.ONNX_MODEL_DIRECTORY`` as ``<stem>.onnx``, where
    ``<stem>`` is the name without ``.pb.gz``.

    Args:
        model_name: Bare file name of the network to convert.

    Returns:
        A ``(models, status)`` tuple: the refreshed model list for the dataframe
        and a status message (``""`` when no conversion was attempted).
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return list_models(), ""
    if model_name.endswith(".onnx"):
        gr.Warning(
            "ONNX conversion not implemented.",
        )
        return list_models(), ""
    # The value comes from the client, so only accept a bare file name. This
    # keeps both the input and output paths inside the model directories.
    if os.path.basename(model_name) != model_name:
        gr.Warning(
            f"Invalid model name `{model_name}`.",
        )
        return list_models(), ""
    leela_suffix = ".pb.gz"
    # Slicing a fixed 6 characters would silently mangle other names
    # (e.g. "net.gz" -> ".onnx"), so require the expected suffix explicitly.
    if not model_name.endswith(leela_suffix):
        gr.Warning(
            f"Only `{leela_suffix}` networks can be converted.",
        )
        return list_models(), ""
    onnx_name = f"{model_name[: -len(leela_suffix)]}.onnx"
    try:
        backends.convert_to_onnx(
            f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
            f"{constants.ONNX_MODEL_DIRECTORY}/{onnx_name}",
        )
    except RuntimeError:
        gr.Warning(
            f"Could not convert net at `{model_name}`.",
        )
        return list_models(), "Conversion failed"
    return list_models(), "Conversion successful"
def upload_model(
    model_file: gr.File,
):
    """
    Store an uploaded Leela network in the Leela model directory.

    Args:
        model_file: Raw bytes of the upload (the ``gr.File`` input is created
            with ``type="binary"``), or ``None`` if nothing was uploaded.

    Returns:
        The refreshed model list for the dataframe. A warning is shown when no
        file was given or ``save_model`` rejects it.
    """
    if model_file is None:
        gr.Warning(
            "File not uploaded.",
        )
        return list_models()
    # Assigned before `try` so the `finally` clause can always reference it.
    # The file is written into the destination directory so that
    # save_model's os.rename stays on one filesystem.
    tmp_file_path = f"{constants.LEELA_MODEL_DIRECTORY}/{uuid.uuid4()}"
    try:
        with open(tmp_file_path, "wb") as f:
            f.write(model_file)
        save_model(tmp_file_path)
    except RuntimeError:
        gr.Warning(
            "Invalid file type.",
        )
    finally:
        # save_model renames the file away on success; drop any leftover otherwise.
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    return list_models()
def get_model_description(
    model_name: str,
):
    """
    Describe a Leela network with ``backends.describenet``.

    Args:
        model_name: Bare file name of a network in ``constants.LEELA_MODEL_DIRECTORY``.

    Returns:
        The description returned by ``backends.describenet``. Returns ``""`` (with
        a warning) when no model is selected, the model is ONNX, or the name
        is not a bare file name.

    Raises:
        gr.Error: If ``backends.describenet`` raises ``RuntimeError``.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return ""
    if model_name.endswith(".onnx"):
        gr.Warning(
            "ONNX description not implemented.",
        )
        return ""
    # The value comes from the client; refuse anything that is not a plain file
    # name so the lookup stays inside the Leela model directory.
    if os.path.basename(model_name) != model_name or model_name in (".", ".."):
        gr.Warning(
            f"Invalid model name `{model_name}`.",
        )
        return ""
    try:
        description = backends.describenet(
            f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}",
        )
    except RuntimeError as err:
        raise gr.Error(
            f"Could not describe net at `{model_name}`.",
        ) from err
    return description
def get_model_path(
    model_name: str,
):
    """
    Resolve a selected model name to the file path offered for download.

    Args:
        model_name: Bare file name of a model. Names ending in ``.onnx`` are
            looked up in ``constants.ONNX_MODEL_DIRECTORY``, all others in
            ``constants.LEELA_MODEL_DIRECTORY``.

    Returns:
        The model's path. Returns ``None`` (with a warning) when no model is
        selected, the name is not a bare file name, or no such file exists.
    """
    if model_name == "":
        gr.Warning(
            "Please select a model.",
        )
        return None
    # The returned path is served to the client, and the name itself comes from
    # the client. Never follow separators or "." / ".." out of the model dirs.
    if os.path.basename(model_name) != model_name or model_name in (".", ".."):
        gr.Warning(
            f"Invalid model name `{model_name}`.",
        )
        return None
    if model_name.endswith(".onnx"):
        model_path = f"{constants.ONNX_MODEL_DIRECTORY}/{model_name}"
    else:
        model_path = f"{constants.LEELA_MODEL_DIRECTORY}/{model_name}"
    if not os.path.isfile(model_path):
        gr.Warning(
            f"Model `{model_name}` not found.",
        )
        return None
    return model_path
with gr.Blocks() as interface:
    # Upload controls: the file arrives as raw bytes (type="binary").
    model_file = gr.File(type="binary")
    upload_button = gr.Button(value="Upload")

    with gr.Row():
        # Left pane: table of available model file names. `value` is a callable,
        # so Gradio re-evaluates it whenever the app loads.
        with gr.Column(scale=2):
            model_df = gr.Dataframe(
                value=list_models,
                headers=["Available models"],
                datatype=["str"],
                type="array",
                interactive=False,
            )
        # Right pane: selection, conversion, description and download controls.
        with gr.Column(scale=1):
            with gr.Row():
                model_name = gr.Textbox(
                    label="Selected model",
                    lines=1,
                    interactive=False,
                    scale=7,
                )
                conversion_status = gr.Textbox(label="Conversion status", lines=1, interactive=False)
            convert_button = gr.Button(value="Convert")
            describe_button = gr.Button(value="Describe model")
            model_description = gr.Textbox(label="Model description", lines=1, interactive=False)
            download_button = gr.Button(value="Get download link")
            download_file = gr.File(type="filepath", label="Download link", interactive=False)

    # Event wiring. Registration order is kept unchanged, since Gradio presumably
    # numbers listener endpoints in the order they are registered.
    model_df.select(fn=on_select_model_df, inputs=None, outputs=model_name)
    upload_button.click(fn=upload_model, inputs=model_file, outputs=model_df)
    convert_button.click(fn=convert_model, inputs=model_name, outputs=[model_df, conversion_status])
    describe_button.click(fn=get_model_description, inputs=model_name, outputs=model_description)
    download_button.click(fn=get_model_path, inputs=model_name, outputs=download_file)