LPX55 commited on
Commit
dd7e749
Β·
verified Β·
1 Parent(s): 2574774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -13,16 +13,26 @@ from transformers import AutoProcessor, pipeline, AutoModelForMaskGeneration
13
  from diffusers.models.attention_processor import Attention
14
  from dataclasses import dataclass
15
  from typing import Any, List, Dict, Optional, Union, Tuple
 
 
16
 
17
  # Ensure that the minimal version of diffusers is installed
18
  check_min_version("0.30.2")
19
  HF_TOKEN = os.getenv("HF_TOKEN")
20
  os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
21
-
 
 
 
 
 
 
 
22
  # Load necessary models and processors
23
  controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
24
  pipe = FluxControlNetInpaintingPipeline.from_pretrained(
25
- "black-forest-labs/FLUX.1-dev",
 
26
  controlnet=controlnet,
27
  torch_dtype=torch.bfloat16,
28
  use_safetensors=True,
 
13
  from diffusers.models.attention_processor import Attention
14
  from dataclasses import dataclass
15
  from typing import Any, List, Dict, Optional, Union, Tuple
16
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
17
+ from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
18
 
19
  # Ensure that the minimal version of diffusers is installed
20
  check_min_version("0.30.2")
21
  HF_TOKEN = os.getenv("HF_TOKEN")
22
  os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
23
+ quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
24
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(
25
+ "black-forest-labs/FLUX.1-dev",
26
+ subfolder="transformer",
27
+ quantization_config=quant_config,
28
+ torch_dtype=torch.bfloat16,
29
+ token=HF_TOKEN
30
+ )
31
  # Load necessary models and processors
32
  controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
33
  pipe = FluxControlNetInpaintingPipeline.from_pretrained(
34
+ "LPX55/FLUX.1-merged_uncensored",
35
+ transformer=transformer_8bit,
36
  controlnet=controlnet,
37
  torch_dtype=torch.bfloat16,
38
  use_safetensors=True,