import functools
import importlib.util
import json
import unittest.mock

import torch

import comfy.model_management
import comfy.model_patcher
import comfy.sd
import folder_paths

from . import patchers, utils

# The velocator nodes are backed by the optional `xelerate` package.
HAS_VELOCATOR = importlib.util.find_spec("xelerate") is not None
def get_quant_inputs():
    return {
        "quant_type": (
            [
                "int8_dynamic",
                "e4m3_e4m3_dynamic",
                "e4m3_e4m3_dynamic_per_tensor",
                "int8_weightonly",
                "e4m3_weightonly",
                "e4m3_e4m3_weightonly",
                "e4m3_e4m3_weightonly_per_tensor",
                "nf4_weightonly",
                "af4_weightonly",
                "int4_weightonly",
            ],
        ),
        "filter_fn": (
            "STRING",
            {
                "default": "fnmatch_matches_fqn",
            },
        ),
        "filter_fn_kwargs": (
            "STRING",
            {
                "multiline": True,
                "default": '{"pattern": ["*"]}',
            },
        ),
        "kwargs": (
            "STRING",
            {
                "multiline": True,
                # "default": "{}",
            },
        ),
    }
class VelocatorLoadAndQuantizeDiffusionModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "unet_name": (folder_paths.get_filename_list("diffusion_models"),),
                **utils.get_weight_dtype_inputs(),
                "lowvram": ("BOOLEAN", {"default": True}),
                "full_load": ("BOOLEAN", {"default": True}),
                "quantize": ("BOOLEAN", {"default": True}),
                "quantize_on_load_device": ("BOOLEAN", {"default": True}),
                **get_quant_inputs(),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_unet"
    CATEGORY = "wavespeed/velocator"

    def load_unet(
        self,
        unet_name,
        weight_dtype,
        lowvram,
        full_load,
        quantize,
        quantize_on_load_device,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        model_options = {}
        if lowvram:
            # Keep weights on CPU initially so they can be streamed to the GPU later.
            model_options["initial_device"] = torch.device("cpu")
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)

        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)

        quantize_fn = None
        if quantize:
            assert HAS_VELOCATOR, "velocator is not installed"
            from xelerate.ao.quant import quantize

            kwargs = json.loads(kwargs) if kwargs else {}
            if lowvram and quantize_on_load_device:
                # Move each tensor to the load device for quantization, then back to CPU.
                preprocessor = lambda t: (
                    t.to(patchers.QuantizedModelPatcher._load_device)
                    if patchers.QuantizedModelPatcher._load_device is not None
                    else t
                )
                kwargs["preprocessor"] = preprocessor
                postprocessor = lambda t: t.to(torch.device("cpu"))
                kwargs["postprocessor"] = postprocessor
            quantize_fn = functools.partial(
                quantize,
                quant_type=quant_type,
                filter_fn=filter_fn,
                filter_fn_kwargs=(
                    json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
                ),
                **kwargs,
            )

        with patchers.QuantizedModelPatcher._override_defaults(
            quantize_fn=quantize_fn,
            lowvram=lowvram,
            full_load=full_load,
        ), utils.disable_load_models_gpu(), unittest.mock.patch.object(
            comfy.model_patcher, "ModelPatcher", patchers.QuantizedModelPatcher
        ):
            model = comfy.sd.load_diffusion_model(
                unet_path, model_options=model_options
            )
        return (model,)
class VelocatorLoadAndQuantizeClip:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_name1": ([""] + folder_paths.get_filename_list("text_encoders"),),
                "clip_name2": ([""] + folder_paths.get_filename_list("text_encoders"),),
                "clip_name3": ([""] + folder_paths.get_filename_list("text_encoders"),),
                "type": ([member.name.lower() for member in comfy.sd.CLIPType],),
                **utils.get_weight_dtype_inputs(),
                "lowvram": ("BOOLEAN", {"default": True}),
                "full_load": ("BOOLEAN", {"default": True}),
                "quantize": ("BOOLEAN", {"default": True}),
                "quantize_on_load_device": ("BOOLEAN", {"default": True}),
                **get_quant_inputs(),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
    CATEGORY = "wavespeed/velocator"

    def load_clip(
        self,
        clip_name1,
        clip_name2,
        clip_name3,
        type,
        weight_dtype,
        lowvram,
        full_load,
        quantize,
        quantize_on_load_device,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        model_options = {}
        if lowvram:
            model_options["initial_device"] = torch.device("cpu")
        model_options = utils.parse_weight_dtype(model_options, weight_dtype)

        # Resolve the CLIPType enum member from its lowercased name.
        clip_paths = []
        clip_type = None
        for clip_type_ in comfy.sd.CLIPType:
            if clip_type_.name.lower() == type:
                clip_type = clip_type_
                break
        assert clip_type is not None, f"Invalid clip type: {type}"

        for clip_name in [clip_name1, clip_name2, clip_name3]:
            if clip_name:
                clip_path = folder_paths.get_full_path_or_raise(
                    "text_encoders", clip_name
                )
                clip_paths.append(clip_path)

        quantize_fn = None
        if quantize:
            assert HAS_VELOCATOR, "velocator is not installed"
            from xelerate.ao.quant import quantize

            kwargs = json.loads(kwargs) if kwargs else {}
            if lowvram and quantize_on_load_device:
                preprocessor = lambda t: (
                    t.to(patchers.QuantizedModelPatcher._load_device)
                    if patchers.QuantizedModelPatcher._load_device is not None
                    else t
                )
                kwargs["preprocessor"] = preprocessor
                postprocessor = lambda t: t.to(torch.device("cpu"))
                kwargs["postprocessor"] = postprocessor
            quantize_fn = functools.partial(
                quantize,
                quant_type=quant_type,
                filter_fn=filter_fn,
                filter_fn_kwargs=(
                    json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
                ),
                **kwargs,
            )

        with patchers.QuantizedModelPatcher._override_defaults(
            quantize_fn=quantize_fn,
            lowvram=lowvram,
            full_load=full_load,
        ), utils.disable_load_models_gpu(), unittest.mock.patch.object(
            comfy.model_patcher, "ModelPatcher", patchers.QuantizedModelPatcher
        ):
            clip = comfy.sd.load_clip(
                ckpt_paths=clip_paths,
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
                clip_type=clip_type,
                model_options=model_options,
            )
        return (clip,)
class VelocatorQuantizeModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                **get_quant_inputs(),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "wavespeed/velocator"

    def patch(
        self,
        model,
        object_to_patch,
        quantize,
        quant_type,
        filter_fn,
        filter_fn_kwargs,
        kwargs,
    ):
        if quantize:
            assert HAS_VELOCATOR, "velocator is not installed"
            # Import inside the branch so the boolean `quantize` input is not shadowed.
            from xelerate.ao.quant import quantize

            # Quantizing in place requires the whole model to be resident on the GPU.
            comfy.model_management.unload_all_models()
            comfy.model_management.load_models_gpu(
                [model], force_patch_weights=True, force_full_load=True
            )
            filter_fn_kwargs = json.loads(filter_fn_kwargs) if filter_fn_kwargs else {}
            kwargs = json.loads(kwargs) if kwargs else {}
            model = model.clone()
            model.add_object_patch(
                object_to_patch,
                quantize(
                    model.get_model_object(object_to_patch),
                    quant_type=quant_type,
                    filter_fn=filter_fn,
                    filter_fn_kwargs=filter_fn_kwargs,
                    **kwargs,
                ),
            )
        return (model,)
class VelocatorCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (utils.any_typ,),
                "is_patcher": (
                    "BOOLEAN",
                    {
                        "default": True,
                    },
                ),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                "memory_format": (
                    ["channels_last", "contiguous_format", "preserve_format"],
                ),
                "fullgraph": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "dynamic": ("BOOLEAN", {"default": False}),
                "mode": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "cache-all:max-autotune:low-precision",
                    },
                ),
                "options": (
                    "STRING",
                    {
                        "multiline": True,
                        # "default": "{}",
                    },
                ),
                "disable": (
                    "BOOLEAN",
                    {
                        "default": False,
                    },
                ),
                "backend": (
                    "STRING",
                    {
                        "default": "velocator",
                    },
                ),
            }
        }

    RETURN_TYPES = (utils.any_typ,)
    FUNCTION = "patch"
    CATEGORY = "wavespeed/velocator"

    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        memory_format,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
        assert HAS_VELOCATOR, "velocator is not installed"
        from xelerate.compilers.xelerate_compiler import xelerate_compile
        from xelerate.utils.memory_format import apply_memory_format

        compile_function = xelerate_compile

        memory_format = getattr(torch, memory_format)
        mode = mode if mode else None
        options = json.loads(options) if options else None
        if backend == "velocator":
            backend = "xelerate"

        if is_patcher:
            patcher = model.clone()
        else:
            patcher = model.patcher
            patcher = patcher.clone()

        patcher.add_object_patch(
            object_to_patch,
            compile_function(
                apply_memory_format(
                    patcher.get_model_object(object_to_patch),
                    memory_format=memory_format,
                ),
                fullgraph=fullgraph,
                dynamic=dynamic,
                mode=mode,
                options=options,
                disable=disable,
                backend=backend,
            ),
        )

        if is_patcher:
            return (patcher,)
        else:
            model.patcher = patcher
            return (model,)
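

# --- Hypothetical registration sketch (not part of the original file). ---
# ComfyUI discovers custom nodes through a NODE_CLASS_MAPPINGS dict exported by
# the package. The actual mappings for these classes likely live elsewhere
# (e.g. the package's __init__.py); the key strings and display names below are
# illustrative assumptions, shown only to make the node classes above usable
# end to end.
NODE_CLASS_MAPPINGS = {
    "VelocatorLoadAndQuantizeDiffusionModel": VelocatorLoadAndQuantizeDiffusionModel,
    "VelocatorLoadAndQuantizeClip": VelocatorLoadAndQuantizeClip,
    "VelocatorQuantizeModel": VelocatorQuantizeModel,
    "VelocatorCompileModel": VelocatorCompileModel,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "VelocatorLoadAndQuantizeDiffusionModel": "Velocator Load & Quantize Diffusion Model",
    "VelocatorLoadAndQuantizeClip": "Velocator Load & Quantize CLIP",
    "VelocatorQuantizeModel": "Velocator Quantize Model",
    "VelocatorCompileModel": "Velocator Compile Model",
}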