Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import inspect | |
| from dataclasses import dataclass | |
| import re | |
| class RegisterTables: | |
| model_classes = {} | |
| frontend_classes = {} | |
| specaug_classes = {} | |
| normalize_classes = {} | |
| encoder_classes = {} | |
| decoder_classes = {} | |
| joint_network_classes = {} | |
| predictor_classes = {} | |
| stride_conv_classes = {} | |
| tokenizer_classes = {} | |
| batch_sampler_classes = {} | |
| dataset_classes = {} | |
| index_ds_classes = {} | |
| def print(self, key=None): | |
| print("\ntables: \n") | |
| fields = vars(self) | |
| for classes_key, classes_dict in fields.items(): | |
| flag = True | |
| if key is not None: | |
| flag = key in classes_key | |
| if classes_key.endswith("_meta") and flag: | |
| print( | |
| f"----------- ** {classes_key.replace('_meta', '')} ** --------------" | |
| ) | |
| headers = ["class name", "class location"] | |
| metas = [] | |
| for register_key, meta in classes_dict.items(): | |
| metas.append(meta) | |
| metas.sort(key=lambda x: x[0]) | |
| data = [headers] + metas | |
| col_widths = [max(len(str(item)) for item in col) for col in zip(*data)] | |
| for row in data: | |
| print( | |
| "| " | |
| + " | ".join( | |
| str(item).ljust(width) | |
| for item, width in zip(row, col_widths) | |
| ) | |
| + " |" | |
| ) | |
| print("\n") | |
| def register(self, register_tables_key: str, key=None): | |
| def decorator(target_class): | |
| if not hasattr(self, register_tables_key): | |
| setattr(self, register_tables_key, {}) | |
| logging.info( | |
| "new registry table has been added: {}".format(register_tables_key) | |
| ) | |
| registry = getattr(self, register_tables_key) | |
| registry_key = key if key is not None else target_class.__name__ | |
| # assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format( | |
| # registry_key, target_class, register_tables_key) | |
| registry[registry_key] = target_class | |
| # meta, headers = ["class name", "register name", "class location"] | |
| register_tables_key_meta = register_tables_key + "_meta" | |
| if not hasattr(self, register_tables_key_meta): | |
| setattr(self, register_tables_key_meta, {}) | |
| registry_meta = getattr(self, register_tables_key_meta) | |
| # doc = target_class.__doc__ | |
| class_file = inspect.getfile(target_class) | |
| class_line = inspect.getsourcelines(target_class)[1] | |
| pattern = r"^.+/funasr/" | |
| class_file = re.sub(pattern, "funasr/", class_file) | |
| meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"] | |
| # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"] | |
| registry_meta[registry_key] = meata_data | |
| # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}") | |
| return target_class | |
| return decorator | |
| tables = RegisterTables() | |
| # from funasr_detach.models.paraformer_streaming.model import ParaformerStreaming | |
| import funasr | |