| | import json |
| | from pathlib import Path |
| | from typing import Literal, Optional |
| |
|
| | import torch |
| | from modules.autoencoder import AutoEncoder, AutoEncoderParams |
| | from modules.conditioner import HFEmbedder |
| | from modules.flux_model import Flux, FluxParams |
| | from safetensors.torch import load_file as load_sft |
| |
|
try:
    # Python >= 3.11 provides StrEnum in the standard library.
    from enum import StrEnum
except ImportError:  # narrowed from a bare `except:` — only the import can fail here
    from enum import Enum

    class StrEnum(str, Enum):
        """Backport shim for Python < 3.11: enum whose members are also `str`."""
| |
|
| |
|
| | from pydantic import BaseModel, ConfigDict |
| | from loguru import logger |
| |
|
| |
|
class ModelVersion(StrEnum):
    """Supported FLUX model releases; values match the upstream repo naming."""

    flux_dev = "flux-dev"
    flux_schnell = "flux-schnell"
| |
|
| |
|
class QuantizationDtype(StrEnum):
    """Quantization / storage dtypes selectable per model component.

    The q-prefixed members are quantized formats; bfloat16/float16 mean
    "keep weights in that plain floating-point dtype".
    """

    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"
    bfloat16 = "bfloat16"
    float16 = "float16"
| |
|
| |
|
class ModelSpec(BaseModel):
    """Complete configuration for loading a FLUX pipeline.

    Bundles architecture parameters, checkpoint/weight locations, per-component
    device and dtype placement, and quantization / offload switches. Consumed by
    the ``load_*`` functions below.
    """

    version: ModelVersion
    params: FluxParams
    ae_params: AutoEncoderParams
    # Flow-transformer checkpoint path; None leaves the model with meta
    # (uninitialized) weights — see load_flow_model.
    ckpt_path: str | None

    clip_path: str | None = "openai/clip-vit-large-patch14"
    ae_path: str | None
    # HuggingFace repo coordinates (repo id plus flow / autoencoder filenames).
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    text_enc_max_length: int = 512
    text_enc_path: str | None
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"
    # Dtypes are stored as strings and converted via into_dtype() at load time.
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"

    # NOTE(review): presumably the number of flow blocks to quantize — confirm
    # against the Flux model implementation.
    num_to_quant: Optional[int] = 20
    quantize_extras: bool = False
    compile_extras: bool = False
    compile_blocks: bool = False
    flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
    ae_quantization_dtype: Optional[QuantizationDtype] = None
    clip_quantization_dtype: Optional[QuantizationDtype] = None
    # Offload switches move the given component to CPU when idle.
    offload_text_encoder: bool = False
    offload_vae: bool = False
    offload_flow: bool = False
    # True when ckpt_path already contains quantized weights (skips re-casting).
    prequantized_flow: bool = False

    # Whether to quantize the modulation layers.
    quantize_modulation: bool = True
    # Whether to quantize the flow embedder layers.
    quantize_flow_embedder_layers: bool = False

    # arbitrary_types_allowed: torch.device fields are not pydantic types;
    # use_enum_values: store enum members as their string values.
    model_config: ConfigDict = {
        "arbitrary_types_allowed": True,
        "use_enum_values": True,
    }
| |
|
| |
|
def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
    """Load every pipeline component for *config*.

    Returns (flow, autoencoder, clip, t5), loaded in that order.
    """
    # Tuple display evaluates left-to-right, preserving the original load order:
    # flow transformer, then VAE, then the two text encoders.
    loaded = (
        load_flow_model(config),
        load_autoencoder(config),
        *load_text_encoders(config),
    )
    return loaded
| |
|
| |
|
| | def parse_device(device: str | torch.device | None) -> torch.device: |
| | if isinstance(device, str): |
| | return torch.device(device) |
| | elif isinstance(device, torch.device): |
| | return device |
| | else: |
| | return torch.device("cuda:0") |
| |
|
| |
|
def into_dtype(dtype: str) -> torch.dtype:
    """Map a dtype name ("float16", "bfloat16", "float32") to a torch.dtype.

    An already-constructed ``torch.dtype`` passes through unchanged.

    Raises:
        ValueError: if *dtype* is not one of the recognized names.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    named = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    resolved = named.get(dtype)
    if resolved is None:
        raise ValueError(f"Invalid dtype: {dtype}")
    return resolved
| |
|
| |
|
| | def into_device(device: str | torch.device | None) -> torch.device: |
| | if isinstance(device, str): |
| | return torch.device(device) |
| | elif isinstance(device, torch.device): |
| | return device |
| | elif isinstance(device, int): |
| | return torch.device(f"cuda:{device}") |
| | else: |
| | return torch.device("cuda:0") |
| |
|
| |
|
def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
    compile_extras: bool = False,
    compile_blocks: bool = False,
    offload_text_enc: bool = False,
    offload_ae: bool = False,
    offload_flow: bool = False,
    quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
    quant_ae: bool = False,
    prequantized_flow: bool = False,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = False,
) -> ModelSpec:
    """
    Load a model configuration using the passed arguments.

    Builds a fully-populated ModelSpec for the chosen FLUX release, with
    hard-coded architecture hyperparameters (FluxParams / AutoEncoderParams)
    and the caller's checkpoint paths, device placement, dtype, compile,
    quantization, and offload settings. Devices default to ``cuda:0``
    (via parse_device) and are stored as strings so the spec stays
    JSON-serializable.
    """
    text_enc_device = str(parse_device(text_enc_device))
    ae_device = str(parse_device(ae_device))
    flux_device = str(parse_device(flux_device))
    return ModelSpec(
        version=name,
        repo_id=(
            "black-forest-labs/FLUX.1-dev"
            if name == ModelVersion.flux_dev
            else "black-forest-labs/FLUX.1-schnell"
        ),
        repo_flow=(
            "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft"
        ),
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        # Architecture constants for the FLUX flow transformer (identical for
        # dev and schnell except guidance embedding, which only dev uses).
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=name == ModelVersion.flux_dev,
        ),
        ae_path=ae_path,
        # Autoencoder (VAE) architecture constants shared by both releases.
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
        text_enc_path=text_enc_path,
        text_enc_device=text_enc_device,
        ae_device=ae_device,
        flux_device=flux_device,
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        # dev supports longer T5 prompts (512 tokens) than schnell (256).
        text_enc_max_length=512 if name == ModelVersion.flux_dev else 256,
        num_to_quant=num_to_quant,
        compile_extras=compile_extras,
        compile_blocks=compile_blocks,
        offload_flow=offload_flow,
        offload_text_encoder=offload_text_enc,
        offload_vae=offload_ae,
        # Translate the CLI-style quant name into the enum; unknown/None -> no
        # text-encoder quantization.
        text_enc_quantization_dtype={
            "float8": QuantizationDtype.qfloat8,
            "qint2": QuantizationDtype.qint2,
            "qint4": QuantizationDtype.qint4,
            "qint8": QuantizationDtype.qint8,
        }.get(quant_text_enc, None),
        ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
        prequantized_flow=prequantized_flow,
        quantize_modulation=quantize_modulation,
        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
    )
| |
|
| |
|
| | def load_config_from_path(path: str) -> ModelSpec: |
| | path_path = Path(path) |
| | if not path_path.exists(): |
| | raise ValueError(f"Path {path} does not exist") |
| | if not path_path.is_file(): |
| | raise ValueError(f"Path {path} is not a file") |
| | return ModelSpec(**json.loads(path_path.read_text())) |
| |
|
| |
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Log warnings for missing / unexpected state-dict keys.

    Emits nothing when both lists are empty. When both are non-empty, a
    dashed separator line is logged between the two messages — exactly as
    the previous three-branch version did, but without duplicating the
    message construction.
    """
    if missing:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        # Visual separator only when both kinds of mismatch are reported.
        logger.warning("\n" + "-" * 79 + "\n")
    if unexpected:
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
| |
|
| |
|
def load_flow_model(config: ModelSpec) -> Flux:
    """Build the Flux flow transformer and optionally load its checkpoint.

    The model is constructed on the "meta" device so no real memory is
    allocated until checkpoint tensors are assigned. If ``config.ckpt_path``
    is None the returned model still holds meta (uninitialized) weights.
    """
    ckpt_path = config.ckpt_path
    FluxClass = Flux

    with torch.device("meta"):
        model = FluxClass(config, dtype=into_dtype(config.flow_dtype))
        if not config.prequantized_flow:
            model.type(into_dtype(config.flow_dtype))

    if ckpt_path is not None:
        # Load to CPU; assign=True swaps the meta tensors for the loaded ones
        # instead of copying into them (copying into meta tensors would fail).
        sd = load_sft(ckpt_path, device="cpu")
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
        if not config.prequantized_flow:
            # Re-cast after assignment: assign=True keeps the checkpoint's own
            # dtype, which may differ from the configured flow dtype.
            model.type(into_dtype(config.flow_dtype))
    return model
| |
|
| |
|
def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    """Construct the CLIP and T5 text encoders described by *config*.

    Both encoders share the same dtype and device index from the config.
    """
    # Hoist the shared conversions; into_dtype / into_device are pure.
    enc_dtype = into_dtype(config.text_enc_dtype)
    # NOTE(review): .index is None for non-indexed devices such as "cpu",
    # so those fall back to index 0 — confirm this is intended.
    enc_index = into_device(config.text_enc_device).index or 0
    clip = HFEmbedder(
        config.clip_path,
        max_length=77,
        torch_dtype=enc_dtype,
        device=enc_index,
        is_clip=True,
        quantization_dtype=config.clip_quantization_dtype,
    )
    t5 = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=enc_dtype,
        device=enc_index,
        quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip, t5
| |
|
| |
|
def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    """Build the autoencoder, optionally loading weights, quantizing, offloading.

    When a checkpoint path is configured the module is first built on the
    "meta" device (no allocation) and real tensors are assigned from the
    checkpoint; otherwise it is built directly on the target device.
    """
    ckpt_path = config.ae_path
    with torch.device("meta" if ckpt_path is not None else config.ae_device):
        ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype))

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(config.ae_device))
        # assign=True replaces the meta tensors with the loaded ones.
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
        # Ensure final device/dtype regardless of the checkpoint's own dtype.
        ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
    if config.ae_quantization_dtype is not None:
        # Deferred import so float8 support stays optional.
        # NOTE(review): presumably swaps nn.Linear layers for float8 variants
        # in place — confirm against float8_quantize.
        from float8_quantize import recursive_swap_linears

        recursive_swap_linears(ae)
    if config.offload_vae:
        # Park the VAE on CPU until needed and release the freed GPU memory.
        ae.to("cpu")
        torch.cuda.empty_cache()
    return ae
| |
|
| |
|
class LoadedModels(BaseModel):
    """Container bundling all loaded pipeline components with their spec.

    Returned by load_models_from_config / load_models_from_config_path.
    """

    flow: Flux
    ae: AutoEncoder
    clip: HFEmbedder
    t5: HFEmbedder
    config: ModelSpec

    # arbitrary_types_allowed: the model fields above are torch modules,
    # not pydantic models; use_enum_values mirrors ModelSpec's config.
    model_config = {
        "arbitrary_types_allowed": True,
        "use_enum_values": True,
    }
| |
|
| |
|
def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    """Load a ModelSpec from the JSON file at *path* and build all models.

    Delegates to ``load_config_from_path`` for parsing and
    ``load_models_from_config`` for construction, removing the loading
    logic this function previously duplicated verbatim. Load order is
    unchanged: text encoders first, then flow, then autoencoder.

    Raises:
        ValueError: if *path* does not exist or is not a file.
    """
    return load_models_from_config(load_config_from_path(path))
| |
|
| |
|
def load_models_from_config(config: ModelSpec) -> LoadedModels:
    """Build every pipeline component for *config* and wrap them in LoadedModels."""
    # Text encoders are loaded first, matching the original ordering; the
    # remaining components load in keyword-argument order below.
    clip_encoder, t5_encoder = load_text_encoders(config)
    bundle = LoadedModels(
        flow=load_flow_model(config),
        ae=load_autoencoder(config),
        clip=clip_encoder,
        t5=t5_encoder,
        config=config,
    )
    return bundle
| |
|