import contextlib
import copy
import unittest.mock

import comfy.model_management
import comfy.model_patcher
import comfy.sd
import comfy.utils
import torch


class QuantizedModelPatcher(comfy.model_patcher.ModelPatcher):
    # Per-instance settings are seeded from these class-level defaults:
    # ComfyUI constructs and clones patchers internally, so class attributes
    # (toggled via the context managers below) are the hook for configuring
    # every new instance.
    _object_to_patch_default = None
    _quantize_fn_default = None
    _lowvram_default = True
    _full_load_default = True
    _is_quantized_default = False

    _load_device = None
    _offload_device = None
    _disable_load = False

    @classmethod
    @contextlib.contextmanager
    def _override_defaults(cls, **kwargs):
        # Temporarily replace the class-level defaults, restoring the old
        # values when the block exits.
        old_defaults = {}
        for k in ("object_to_patch", "quantize_fn", "lowvram", "full_load"):
            if k in kwargs:
                old_defaults[k] = getattr(cls, f"_{k}_default")
                setattr(cls, f"_{k}_default", kwargs[k])
        try:
            yield
        finally:
            for k in old_defaults:
                setattr(cls, f"_{k}_default", old_defaults[k])

    @classmethod
    @contextlib.contextmanager
    def _set_disable_load(cls, disable_load=True):
        # Make load() a no-op for the duration of the block.
        old_disable_load = cls._disable_load
        cls._disable_load = disable_load
        try:
            yield
        finally:
            cls._disable_load = old_disable_load
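
    # Illustrative usage of the two context managers above (an assumption,
    # not from the original file); `build_patcher` is a hypothetical helper
    # that constructs patchers internally:
    #
    #   with QuantizedModelPatcher._set_disable_load():
    #       patcher = build_patcher()  # any load() in here is a no-op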

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._object_to_patch = QuantizedModelPatcher._object_to_patch_default
        self._quantize_fn = QuantizedModelPatcher._quantize_fn_default
        self._lowvram = QuantizedModelPatcher._lowvram_default
        self._full_load = QuantizedModelPatcher._full_load_default
        self._is_quantized = QuantizedModelPatcher._is_quantized_default

    def load(
        self, device_to=None, force_patch_weights=False, full_load=False, **kwargs
    ):
        if self._disable_load:
            return
        if self._is_quantized:
            # Already quantized: defer to the stock ModelPatcher behaviour.
            super().load(
                device_to=device_to,
                force_patch_weights=force_patch_weights,
                full_load=full_load,
                **kwargs,
            )
            return
        # First load: expose the devices to the quantize callback via the
        # class attributes, load fully patched weights (on CPU for lowvram),
        # then quantize exactly once.
        with unittest.mock.patch.object(
            QuantizedModelPatcher, "_load_device", self.load_device
        ), unittest.mock.patch.object(
            QuantizedModelPatcher, "_offload_device", self.offload_device
        ):
            # always call `patch_weight_to_device` even for lowvram
            super().load(
                torch.device("cpu") if self._lowvram else device_to,
                force_patch_weights=True,
                full_load=self._full_load or full_load,
                **kwargs,
            )
            if self._quantize_fn is not None:
                # Quantize either the whole model or just the named submodule.
                if self._object_to_patch is None:
                    target_model = self.model
                else:
                    target_model = comfy.utils.get_attr(
                        self.model, self._object_to_patch
                    )
                target_model = self._quantize_fn(target_model)
                if self._object_to_patch is None:
                    self.model = target_model
                else:
                    comfy.utils.set_attr(
                        self.model, self._object_to_patch, target_model
                    )
            if self._lowvram:
                # Weights were patched on CPU; free cached blocks before the
                # final move onto the target device.
                if device_to.type == "cuda":
                    torch.cuda.empty_cache()
                self.model.to(device_to)
            self._is_quantized = True

    # def model_size(self):
    #     return super().model_size() // 2

    def clone(self, *args, **kwargs):
        # Mirror ModelPatcher.clone(), then carry the quantization state over
        # so a clone of an already-quantized patcher is not quantized again.
        n = QuantizedModelPatcher(
            self.model,
            self.load_device,
            self.offload_device,
            self.size,
            weight_inplace_update=self.weight_inplace_update,
        )
        n.patches = {}
        for k in self.patches:
            n.patches[k] = self.patches[k][:]
        n.patches_uuid = self.patches_uuid
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        n.backup = self.backup
        n.object_patches_backup = self.object_patches_backup
        n._object_to_patch = getattr(
            self, "_object_to_patch", QuantizedModelPatcher._object_to_patch_default
        )
        n._quantize_fn = getattr(
            self, "_quantize_fn", QuantizedModelPatcher._quantize_fn_default
        )
        n._lowvram = getattr(self, "_lowvram", QuantizedModelPatcher._lowvram_default)
        n._full_load = getattr(
            self, "_full_load", QuantizedModelPatcher._full_load_default
        )
        n._is_quantized = getattr(
            self, "_is_quantized", QuantizedModelPatcher._is_quantized_default
        )
        return n
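

# --- Usage sketch (illustrative assumption, not part of the original file) ---
# How this patcher is presumably wired in: override the class defaults with a
# quantization callback, then let ComfyUI's stock loader construct the patcher
# while the ModelPatcher name is redirected to this subclass.
# `comfy.sd.load_checkpoint_guess_config` is the standard ComfyUI entry point;
# the checkpoint path, the "diffusion_model" target, and `_quantize` below are
# placeholders.
#
# def _quantize(module):
#     # e.g. swap nn.Linear weights for a quantized representation here;
#     # placeholder only.
#     return module
#
# with unittest.mock.patch.object(
#     comfy.model_patcher, "ModelPatcher", QuantizedModelPatcher
# ), QuantizedModelPatcher._override_defaults(
#     object_to_patch="diffusion_model", quantize_fn=_quantize
# ):
#     model_patcher, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
#         "/path/to/checkpoint.safetensors"
#     )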