danhtran2mind committed
Commit 73f8596 · verified · 1 Parent(s): 3c157f4

Delete src/controlnet_image_generator/old-infer.py

src/controlnet_image_generator/old-infer.py DELETED
@@ -1,102 +0,0 @@
- import cv2
- import torch
- from PIL import Image
- import numpy as np
- import yaml
- import argparse
- from controlnet_aux import OpenposeDetector
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
- from utils.download import load_image
- from utils.plot import image_grid
-
- def load_config(config_path):
-     with open(config_path, 'r') as file:
-         return yaml.safe_load(file)
-
- def initialize_controlnet(config):
-     model_id = config['model_id']
-     local_dir = config.get('local_dir', model_id)
-     return ControlNetModel.from_pretrained(
-         local_dir if local_dir != model_id else model_id,
-         torch_dtype=torch.float16
-     )
-
- def initialize_pipeline(controlnet, config):
-     model_id = config['model_id']
-     local_dir = config.get('local_dir', model_id)
-     pipe = StableDiffusionControlNetPipeline.from_pretrained(
-         local_dir if local_dir != model_id else model_id,
-         controlnet=controlnet,
-         torch_dtype=torch.float16
-     )
-     pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-     return pipe
-
- def setup_device(pipe):
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     if device == "cuda":
-         pipe.enable_model_cpu_offload()
-     pipe.to(device)
-     return device
-
- def generate_images(pipe, prompts, pose_images, generators, negative_prompts, num_steps):
-     return pipe(
-         prompts,
-         pose_images,
-         negative_prompt=negative_prompts,
-         generator=generators,
-         num_inference_steps=num_steps
-     ).images
-
- def infer(args):
-     # Load configuration
-     configs = load_config(args.config_path)
-
-     # Initialize models
-     controlnet_detector = OpenposeDetector.from_pretrained(
-         configs[2]['model_id']  # lllyasviel/ControlNet
-     )
-     controlnet = initialize_controlnet(configs[0])
-     pipe = initialize_pipeline(controlnet, configs[1])
-
-     # Setup device
-     device = setup_device(pipe)
-
-     # Load and process image
-     demo_image = load_image(args.image_url)
-     poses = [controlnet_detector(demo_image)]
-
-     # Generate images
-     generators = [torch.Generator(device="cpu").manual_seed(args.seed) for _ in range(len(poses))]
-
-     output_images = generate_images(
-         pipe,
-         [args.prompt] * len(generators),
-         poses,
-         generators,
-         [args.negative_prompt] * len(generators),
-         args.num_steps
-     )
-
-     # Display results
-     # image_grid(output_images, 2, 2)
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="ControlNet image generation with pose detection")
-     parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml",
-                         help="Path to configuration YAML file")
-     parser.add_argument("--image_url", type=str,
-                         default="https://huggingface.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg",
-                         help="URL of input image")
-     parser.add_argument("--prompt", type=str, default="a man is doing yoga",
-                         help="Text prompt for image generation")
-     parser.add_argument("--negative_prompt", type=str,
-                         default="monochrome, lowres, bad anatomy, worst quality, low quality",
-                         help="Negative prompt for image generation")
-     parser.add_argument("--num_steps", type=int, default=20,
-                         help="Number of inference steps")
-     parser.add_argument("--seed", type=int, default=2,
-                         help="Random seed for generation")
-     # return parser.parse_args()
-     args = parser.parse_args()
-     infer(args)
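
For reference, the deleted script indexed the parsed YAML as a list: configs[0] for the ControlNet weights, configs[1] for the base Stable Diffusion pipeline, and configs[2] for the OpenPose detector. The sketch below is a minimal, hypothetical reconstruction of the layout configs/model_ckpts.yaml would need for that indexing to work; the actual file is not part of this commit, and the model ids for entries [0] and [1] are assumptions (only the [2] id, lllyasviel/ControlNet, is named in the script's inline comment).

# Hypothetical sketch only: the real configs/model_ckpts.yaml is not shown in this commit.
import yaml

example_yaml = """
- model_id: lllyasviel/sd-controlnet-openpose   # [0] ControlNet weights (assumed id)
  local_dir: ckpts/sd-controlnet-openpose       # optional local checkpoint path
- model_id: runwayml/stable-diffusion-v1-5      # [1] base pipeline (assumed id)
- model_id: lllyasviel/ControlNet               # [2] OpenPose detector (per the inline comment)
"""

configs = yaml.safe_load(example_yaml)
assert configs[2]["model_id"] == "lllyasviel/ControlNet"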