import re
from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
import torch
from loguru import logger
from safetensors.torch import load_file
from tqdm import tqdm
from torch import nn

try:
    from cublas_ops import CublasLinear
except Exception as e:
    CublasLinear = type(None)
from float8_quantize import F8Linear
from modules.flux_model import Flux

path_regex = re.compile(r"/|\\")

StateDict: TypeAlias = OrderedDict[str, torch.Tensor]

class LoraWeights:
    def __init__(
        self,
        weights: StateDict,
        path: str,
        name: Optional[str] = None,
        scale: float = 1.0,
    ) -> None:
        self.path = path
        self.weights = weights
        self.name = name if name else path_regex.split(path)[-1]
        self.scale = scale

def swap_scale_shift(weight):
    scale, shift = weight.chunk(2, dim=0)
    new_weight = torch.cat([shift, scale], dim=0)
    return new_weight

def check_if_lora_exists(state_dict, lora_name):
    subkey = lora_name.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0]
    for key in state_dict.keys():
        if subkey in key:
            return subkey
    return False

def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name):
    if (original_stubkey := check_if_lora_exists(state_dict, lora_name)) != False:
        weights_to_pop = [k for k in state_dict.keys() if original_stubkey in k]
        for key in weights_to_pop:
            key_replacement = key.replace(
                original_stubkey, flux_layer_name.replace(".weight", "")
            )
            new_state_dict[key_replacement] = state_dict.pop(key)
        return new_state_dict, state_dict
    else:
        return new_state_dict, state_dict

def convert_diffusers_to_flux_transformer_checkpoint(
    diffusers_state_dict,
    num_layers,
    num_single_layers,
    has_guidance=True,
    prefix="",
):
    original_state_dict = {}

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}time_text_embed.timestep_embedder.linear_1.weight",
        "time_in.in_layer.weight",
    )

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}time_text_embed.text_embedder.linear_1.weight",
        "vector_in.in_layer.weight",
    )

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}time_text_embed.text_embedder.linear_2.weight",
        "vector_in.out_layer.weight",
    )

    if has_guidance:
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}time_text_embed.guidance_embedder.linear_1.weight",
            "guidance_in.in_layer.weight",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}time_text_embed.guidance_embedder.linear_2.weight",
            "guidance_in.out_layer.weight",
        )

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}context_embedder.weight",
        "txt_in.weight",
    )

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}x_embedder.weight",
        "img_in.weight",
    )

    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}norm1.linear.weight",
            f"double_blocks.{i}.img_mod.lin.weight",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}norm1_context.linear.weight",
            f"double_blocks.{i}.txt_mod.lin.weight",
        )

        temp_dict = {}

        expected_shape_qkv_a = None
        expected_shape_qkv_b = None
        expected_shape_add_qkv_a = None
        expected_shape_add_qkv_b = None
        dtype = None
        device = None

        for component in [
            "to_q",
            "to_k",
            "to_v",
            "add_q_proj",
            "add_k_proj",
            "add_v_proj",
        ]:
            sample_component_A_key = (
                f"{prefix}{block_prefix}attn.{component}.lora_A.weight"
            )
            sample_component_B_key = (
                f"{prefix}{block_prefix}attn.{component}.lora_B.weight"
            )
            if (
                sample_component_A_key in diffusers_state_dict
                and sample_component_B_key in diffusers_state_dict
            ):
                sample_component_A = diffusers_state_dict.pop(sample_component_A_key)
                sample_component_B = diffusers_state_dict.pop(sample_component_B_key)
                temp_dict[f"{component}"] = [sample_component_A, sample_component_B]
                if expected_shape_qkv_a is None and not component.startswith("add_"):
                    expected_shape_qkv_a = sample_component_A.shape
                    expected_shape_qkv_b = sample_component_B.shape
                    dtype = sample_component_A.dtype
                    device = sample_component_A.device
                if expected_shape_add_qkv_a is None and component.startswith("add_"):
                    expected_shape_add_qkv_a = sample_component_A.shape
                    expected_shape_add_qkv_b = sample_component_B.shape
                    dtype = sample_component_A.dtype
                    device = sample_component_A.device
            else:
                logger.info(
                    f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}"
                )
                temp_dict[f"{component}"] = [None, None]

        if device is not None:
            if expected_shape_qkv_a is not None:
                if (sq := temp_dict["to_q"])[0] is not None:
                    sample_q_A, sample_q_B = sq
                else:
                    sample_q_A, sample_q_B = [
                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
                    ]
                if (sq := temp_dict["to_k"])[0] is not None:
                    sample_k_A, sample_k_B = sq
                else:
                    sample_k_A, sample_k_B = [
                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
                    ]
                if (sq := temp_dict["to_v"])[0] is not None:
                    sample_v_A, sample_v_B = sq
                else:
                    sample_v_A, sample_v_B = [
                        torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device),
                        torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device),
                    ]
                original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
                    torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
                )
                original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
                    torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
                )
            if expected_shape_add_qkv_a is not None:
                if (sq := temp_dict["add_q_proj"])[0] is not None:
                    context_q_A, context_q_B = sq
                else:
                    context_q_A, context_q_B = [
                        torch.zeros(
                            expected_shape_add_qkv_a, dtype=dtype, device=device
                        ),
                        torch.zeros(
                            expected_shape_add_qkv_b, dtype=dtype, device=device
                        ),
                    ]
                if (sq := temp_dict["add_k_proj"])[0] is not None:
                    context_k_A, context_k_B = sq
                else:
                    context_k_A, context_k_B = [
                        torch.zeros(
                            expected_shape_add_qkv_a, dtype=dtype, device=device
                        ),
                        torch.zeros(
                            expected_shape_add_qkv_b, dtype=dtype, device=device
                        ),
                    ]
                if (sq := temp_dict["add_v_proj"])[0] is not None:
                    context_v_A, context_v_B = sq
                else:
                    context_v_A, context_v_B = [
                        torch.zeros(
                            expected_shape_add_qkv_a, dtype=dtype, device=device
                        ),
                        torch.zeros(
                            expected_shape_add_qkv_b, dtype=dtype, device=device
                        ),
                    ]

                original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
                    torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
                )
                original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
                    torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
                )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.norm_q.weight",
            f"double_blocks.{i}.img_attn.norm.query_norm.scale",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.norm_k.weight",
            f"double_blocks.{i}.img_attn.norm.key_norm.scale",
        )
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.norm_added_q.weight",
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale",
        )
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.norm_added_k.weight",
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}ff.net.0.proj.weight",
            f"double_blocks.{i}.img_mlp.0.weight",
        )
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}ff.net.2.weight",
            f"double_blocks.{i}.img_mlp.2.weight",
        )
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}ff_context.net.0.proj.weight",
            f"double_blocks.{i}.txt_mlp.0.weight",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}ff_context.net.2.weight",
            f"double_blocks.{i}.txt_mlp.2.weight",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.to_out.0.weight",
            f"double_blocks.{i}.img_attn.proj.weight",
        )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}attn.to_add_out.weight",
            f"double_blocks.{i}.txt_attn.proj.weight",
        )

    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        key_norm = f"{prefix}{block_prefix}norm.linear.weight"
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            key_norm,
            f"single_blocks.{i}.modulation.lin.weight",
        )

        has_q, has_k, has_v, has_mlp = False, False, False, False
        shape_qkv_a = None
        shape_qkv_b = None

        # Pop Q/K/V/MLP LoRA weights if present; missing entries fall back to None.
        q_A = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_q.lora_A.weight", None
        )
        q_B = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_q.lora_B.weight", None
        )
        if q_A is not None and q_B is not None:
            has_q = True
            shape_qkv_a = q_A.shape
            shape_qkv_b = q_B.shape
        k_A = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_k.lora_A.weight", None
        )
        k_B = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_k.lora_B.weight", None
        )
        if k_A is not None and k_B is not None:
            has_k = True
            shape_qkv_a = k_A.shape
            shape_qkv_b = k_B.shape
        v_A = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_v.lora_A.weight", None
        )
        v_B = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}attn.to_v.lora_B.weight", None
        )
        if v_A is not None and v_B is not None:
            has_v = True
            shape_qkv_a = v_A.shape
            shape_qkv_b = v_B.shape
        mlp_A = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}proj_mlp.lora_A.weight", None
        )
        mlp_B = diffusers_state_dict.pop(
            f"{prefix}{block_prefix}proj_mlp.lora_B.weight", None
        )
        if mlp_A is not None and mlp_B is not None:
            has_mlp = True
            shape_qkv_a = mlp_A.shape
            shape_qkv_b = mlp_B.shape
        if any([has_q, has_k, has_v, has_mlp]):
            # Infer dtype/device from whichever LoRA tensor is present so the
            # zero-filled placeholders below match it.
            reference = next(t for t in (q_A, k_A, v_A, mlp_A) if t is not None)
            dtype = reference.dtype
            device = reference.device
            if not has_q:
                q_A, q_B = [
                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
                ]
            if not has_k:
                k_A, k_B = [
                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
                ]
            if not has_v:
                v_A, v_B = [
                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
                ]
            if not has_mlp:
                mlp_A, mlp_B = [
                    torch.zeros(shape_qkv_a, dtype=dtype, device=device),
                    torch.zeros(shape_qkv_b, dtype=dtype, device=device),
                ]
            original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
                [q_A, k_A, v_A, mlp_A], dim=0
            )
            original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
                [q_B, k_B, v_B, mlp_B], dim=0
            )

        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{prefix}{block_prefix}proj_out.weight",
            f"single_blocks.{i}.linear2.weight",
        )

    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}proj_out.weight",
        "final_layer.linear.weight",
    )
    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}proj_out.bias",
        "final_layer.linear.bias",
    )
    original_state_dict, diffusers_state_dict = convert_if_lora_exists(
        original_state_dict,
        diffusers_state_dict,
        f"{prefix}norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.weight",
    )
    if len(list(diffusers_state_dict.keys())) > 0:
        logger.warning(f"Unexpected keys: {list(diffusers_state_dict.keys())}")

    return original_state_dict
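
# Illustrative sketch (not part of the module's API surface): given a
# diffusers-style LoRA state dict, the converter above remaps keys such as
# "transformer.transformer_blocks.0.attn.to_q.lora_A.weight" into the Flux
# layout, e.g. "double_blocks.0.img_attn.qkv.lora_A.weight". A hypothetical
# call with the layer counts used elsewhere in this file (19 double blocks,
# 38 single blocks) would look like:
#
#     flux_sd = convert_diffusers_to_flux_transformer_checkpoint(
#         diffusers_sd, 19, 38, has_guidance=True, prefix="transformer."
#     )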


def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict:
    """
    Convert a LoRA state dict using the original Flux naming scheme
    (``lora_unet_*`` keys with ``lora_up``/``lora_down``) to the key format
    used by this module.

    Args:
        original_state_dict (Dict[str, torch.Tensor]): The original Flux LoRA state dict.

    Returns:
        Dict[str, torch.Tensor]: The converted state dict in the new format.
    """
    sd = {
        k.replace("lora_unet_", "")
        .replace("double_blocks_", "double_blocks.")
        .replace("single_blocks_", "single_blocks.")
        .replace("_img_attn_", ".img_attn.")
        .replace("_txt_attn_", ".txt_attn.")
        .replace("_img_mod_", ".img_mod.")
        .replace("_txt_mod_", ".txt_mod.")
        .replace("_img_mlp_", ".img_mlp.")
        .replace("_txt_mlp_", ".txt_mlp.")
        .replace("_linear1", ".linear1")
        .replace("_linear2", ".linear2")
        .replace("_modulation_", ".modulation.")
        .replace("lora_up", "lora_B")
        .replace("lora_down", "lora_A"): v
        for k, v in original_state_dict.items()
        if "lora" in k
    }
    return sd
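
# Illustrative sketch of the renaming performed above (hypothetical keys):
#     "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"
#         -> "double_blocks.0.img_attn.qkv.lora_B.weight"
#     "lora_unet_single_blocks_7_linear1.lora_down.weight"
#         -> "single_blocks.7.linear1.lora_A.weight"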


def get_module_for_key(
    key: str, model: Flux
) -> F8Linear | torch.nn.Linear | CublasLinear:
    parts = key.split(".")
    module = model
    for part in parts:
        module = getattr(module, part)
    return module


def get_lora_for_key(
    key: str, lora_weights: dict
) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]:
    """
    Get LoRA weights for a specific key.

    Args:
        key (str): The key to look up in the LoRA weights.
        lora_weights (dict): Dictionary containing LoRA weights.

    Returns:
        Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: A tuple containing lora_A, lora_B, and alpha if found, None otherwise.
    """
    prefix = key.split(".lora")[0]
    lora_A = lora_weights.get(f"{prefix}.lora_A.weight")
    lora_B = lora_weights.get(f"{prefix}.lora_B.weight")
    alpha = lora_weights.get(f"{prefix}.alpha")

    if lora_A is None or lora_B is None:
        return None
    return lora_A, lora_B, alpha
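
# Illustrative sketch (hypothetical key): for "double_blocks.0.img_attn.qkv"
# this looks up "double_blocks.0.img_attn.qkv.lora_A.weight", ".lora_B.weight"
# and ".alpha", returning (lora_A, lora_B, alpha), or None if either matrix is missing.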
def calculate_lora_weight(
    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float, None]],
    rank: Optional[int] = None,
    lora_scale: float = 1.0,
    device: Optional[Union[torch.device, int, str]] = None,
):
    lora_A, lora_B, alpha = lora_weights
    if device is None:
        device = lora_A.device

    uneven_rank = lora_B.shape[1] != lora_A.shape[0]
    rank_diff = lora_A.shape[0] / lora_B.shape[1]

    if rank is None:
        rank = lora_B.shape[1]
    if alpha is None:
        alpha = rank

    dtype = torch.float32
    # Note: despite the names, w_up holds lora_A (the down-projection) and
    # w_down holds lora_B (the up-projection); the product below is the usual
    # lora_B @ lora_A delta.
    w_up = lora_A.to(dtype=dtype, device=device)
    w_down = lora_B.to(dtype=dtype, device=device)

    if alpha != rank:
        w_up = w_up * alpha / rank
    if uneven_rank:
        # lora_A has more rows than lora_B has columns; split lora_A into
        # equal chunks and accumulate each partial product.
        fused_lora = torch.zeros(
            (lora_B.shape[0], lora_A.shape[1]), device=device, dtype=dtype
        )
        w_up = w_up.chunk(int(rank_diff), dim=0)
        for w_up_chunk in w_up:
            fused_lora = fused_lora + (lora_scale * torch.mm(w_down, w_up_chunk))
    else:
        fused_lora = lora_scale * torch.mm(w_down, w_up)
    return fused_lora
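
# Illustrative sketch of the fused delta computed above (toy shapes, assuming a
# rank-16 LoRA on a 3072x3072 linear layer):
#     lora_A: (16, 3072), lora_B: (3072, 16), alpha: 16
#     delta = lora_scale * (alpha / rank) * (lora_B @ lora_A)   # shape (3072, 3072)
# The fused module weight is then W + delta (see apply_lora_weight_to_module).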


@torch.inference_mode()
def unfuse_lora_weight_from_module(
    fused_weight: torch.Tensor,
    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float, None]],
    rank: Optional[int] = None,
    lora_scale: float = 1.0,
):
    w_dtype = fused_weight.dtype
    dtype = torch.float32
    device = fused_weight.device

    fused_weight = fused_weight.to(dtype=dtype, device=device)
    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
    module_weight = fused_weight - fused_lora
    return module_weight.to(dtype=w_dtype, device=device)


@torch.inference_mode()
def apply_lora_weight_to_module(
    module_weight: torch.Tensor,
    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float, None]],
    rank: Optional[int] = None,
    lora_scale: float = 1.0,
):
    w_dtype = module_weight.dtype
    dtype = torch.float32
    device = module_weight.device

    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
    fused_weight = module_weight.to(dtype=dtype) + fused_lora
    return fused_weight.to(dtype=w_dtype, device=device)


def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
    check_if_starts_with_transformer = [
        k for k in lora_weights.keys() if k.startswith("transformer.")
    ]
    if len(check_if_starts_with_transformer) > 0:
        lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
            lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
        )
    else:
        lora_weights = convert_from_original_flux_checkpoint(lora_weights)
    logger.info("LoRA weights loaded")
    logger.debug("Extracting keys")
    keys_without_ab = list(
        set(
            [
                key.replace(".lora_A.weight", "")
                .replace(".lora_B.weight", "")
                .replace(".lora_A", "")
                .replace(".lora_B", "")
                .replace(".alpha", "")
                for key in lora_weights.keys()
            ]
        )
    )
    logger.debug("Keys extracted")
    return keys_without_ab, lora_weights


def get_lora_weights(lora_path: str | StateDict):
    if isinstance(lora_path, (dict, LoraWeights)):
        return lora_path, True
    else:
        return load_file(lora_path, "cpu"), False


def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
    dtype = linear.weight.dtype
    weight_is_f8 = False
    if isinstance(linear, F8Linear):
        weight_is_f8 = True
        weight = (
            linear.float8_data.clone()
            .detach()
            .float()
            .mul(linear.scale_reciprocal)
            .to(linear.weight.device)
        )
    elif isinstance(linear, torch.nn.Linear):
        weight = linear.weight.clone().detach().float()
    elif isinstance(linear, CublasLinear) and CublasLinear != type(None):
        weight = linear.weight.clone().detach().float()
    else:
        raise TypeError(f"Unsupported linear layer type: {type(linear)}")
    return weight, weight_is_f8, dtype


@torch.inference_mode()
def apply_lora_to_model(
    model: Flux,
    lora_path: str | StateDict,
    lora_scale: float = 1.0,
    return_lora_resolved: bool = False,
) -> Flux:
    has_guidance = model.params.guidance_embed
    logger.info(f"Loading LoRA weights for {lora_path}")
    lora_weights, already_loaded = get_lora_weights(lora_path)

    if not already_loaded:
        keys_without_ab, lora_weights = resolve_lora_state_dict(
            lora_weights, has_guidance
        )
    elif isinstance(lora_weights, LoraWeights):
        b_ = lora_weights
        lora_weights = b_.weights
        keys_without_ab = list(
            set(
                [
                    key.replace(".lora_A.weight", "")
                    .replace(".lora_B.weight", "")
                    .replace(".lora_A", "")
                    .replace(".lora_B", "")
                    .replace(".alpha", "")
                    for key in lora_weights.keys()
                ]
            )
        )
    else:
        keys_without_ab = list(
            set(
                [
                    key.replace(".lora_A.weight", "")
                    .replace(".lora_B.weight", "")
                    .replace(".lora_A", "")
                    .replace(".lora_B", "")
                    .replace(".alpha", "")
                    for key in lora_weights.keys()
                ]
            )
        )

    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
        module = get_module_for_key(key, model)
        weight, is_f8, dtype = extract_weight_from_linear(module)
        lora_sd = get_lora_for_key(key, lora_weights)
        if lora_sd is None:
            # no LoRA weights for this module, skip
            continue
        weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
        if is_f8:
            module.set_weight_tensor(weight.type(dtype))
        else:
            module.weight.data = weight.type(dtype)
    logger.success("LoRA applied")
    if return_lora_resolved:
        return model, lora_weights
    return model
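
# Illustrative usage sketch (hypothetical `flux_model` and path; not executed here):
#
#     flux_model = ...  # a loaded modules.flux_model.Flux instance
#     flux_model = apply_lora_to_model(flux_model, "my_lora.safetensors", lora_scale=0.8)
#     ...
#     flux_model = remove_lora_from_module(flux_model, "my_lora.safetensors", lora_scale=0.8)
#
# Passing the same lora_scale to remove_lora_from_module unfuses the delta that
# apply_lora_to_model added.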


def remove_lora_from_module(
    model: Flux,
    lora_path: str | StateDict,
    lora_scale: float = 1.0,
):
    has_guidance = model.params.guidance_embed
    logger.info(f"Loading LoRA weights for {lora_path}")
    lora_weights, already_loaded = get_lora_weights(lora_path)

    if not already_loaded:
        keys_without_ab, lora_weights = resolve_lora_state_dict(
            lora_weights, has_guidance
        )
    elif isinstance(lora_weights, LoraWeights):
        b_ = lora_weights
        lora_weights = b_.weights
        keys_without_ab = list(
            set(
                [
                    key.replace(".lora_A.weight", "")
                    .replace(".lora_B.weight", "")
                    .replace(".lora_A", "")
                    .replace(".lora_B", "")
                    .replace(".alpha", "")
                    for key in lora_weights.keys()
                ]
            )
        )
        lora_scale = b_.scale
    else:
        keys_without_ab = list(
            set(
                [
                    key.replace(".lora_A.weight", "")
                    .replace(".lora_B.weight", "")
                    .replace(".lora_A", "")
                    .replace(".lora_B", "")
                    .replace(".alpha", "")
                    for key in lora_weights.keys()
                ]
            )
        )

    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
        module = get_module_for_key(key, model)
        weight, is_f8, dtype = extract_weight_from_linear(module)
        lora_sd = get_lora_for_key(key, lora_weights)
        if lora_sd is None:
            # no LoRA weights for this module, skip
            continue
        weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
        if is_f8:
            module.set_weight_tensor(weight.type(dtype))
        else:
            module.weight.data = weight.type(dtype)
    logger.success("LoRA unfused")
    return model