Goodis commited on
Commit
efc1c32
·
verified ·
1 Parent(s): 7bc0806

Upload 13 files

Browse files
SwarmRemBg.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+
5
+ class SwarmRemBg:
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {
9
+ "required": {
10
+ "images": ("IMAGE",),
11
+ }
12
+ }
13
+
14
+ CATEGORY = "SwarmUI/images"
15
+ RETURN_TYPES = ("IMAGE", "MASK",)
16
+ FUNCTION = "rem"
17
+
18
+ def rem(self, images):
19
+ from rembg import remove
20
+
21
+ output = []
22
+ masks = []
23
+ for image in images:
24
+ i = 255.0 * image.cpu().numpy()
25
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
26
+ img = img.convert("RGBA")
27
+ img = remove(img, post_process_mask=True)
28
+ output.append(np.array(img).astype(np.float32) / 255.0)
29
+ if 'A' in img.getbands():
30
+ mask = np.array(img.getchannel('A')).astype(np.float32) / 255.0
31
+ masks.append(1. - mask)
32
+ else:
33
+ masks.append(np.zeros((64,64), dtype=np.float32))
34
+ return (torch.from_numpy(np.array(output)), torch.from_numpy(np.array(masks)))
35
+
36
+ NODE_CLASS_MAPPINGS = {
37
+ "SwarmRemBg": SwarmRemBg,
38
+ }
SwarmSaveAnimationWS.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy, folder_paths, io, struct, subprocess, os, random, sys, time
2
+ from PIL import Image
3
+ import numpy as np
4
+ from server import PromptServer, BinaryEventTypes
5
+ from imageio_ffmpeg import get_ffmpeg_exe
6
+
7
+ SPECIAL_ID = 12345
8
+ VIDEO_ID = 12346
9
+ FFMPEG_PATH = get_ffmpeg_exe()
10
+
11
+
12
+ class SwarmSaveAnimationWS:
13
+ methods = {"default": 4, "fastest": 0, "slowest": 6}
14
+
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {
18
+ "required": {
19
+ "images": ("IMAGE", ),
20
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
21
+ "lossless": ("BOOLEAN", {"default": True}),
22
+ "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
23
+ "method": (list(s.methods.keys()),),
24
+ "format": (["webp", "gif", "gif-hd", "h264-mp4", "h265-mp4", "webm", "prores"],),
25
+ },
26
+ }
27
+
28
+ CATEGORY = "SwarmUI/video"
29
+ RETURN_TYPES = ()
30
+ FUNCTION = "save_images"
31
+ OUTPUT_NODE = True
32
+
33
+ def save_images(self, images, fps, lossless, quality, method, format):
34
+ method = self.methods.get(method)
35
+ if images.shape[0] == 0:
36
+ return { }
37
+ if images.shape[0] == 1:
38
+ pbar = comfy.utils.ProgressBar(SPECIAL_ID)
39
+ i = 255.0 * images[0].cpu().numpy()
40
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
41
+ pbar.update_absolute(0, SPECIAL_ID, ("PNG", img, None))
42
+ return { }
43
+
44
+ out_img = io.BytesIO()
45
+ if format in ["webp", "gif"]:
46
+ if format == "webp":
47
+ type_num = 3
48
+ else:
49
+ type_num = 4
50
+ pil_images = []
51
+ for image in images:
52
+ i = 255. * image.cpu().numpy()
53
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
54
+ pil_images.append(img)
55
+ pil_images[0].save(out_img, save_all=True, duration=int(1000.0 / fps), append_images=pil_images[1 : len(pil_images)], lossless=lossless, quality=quality, method=method, format=format.upper(), loop=0)
56
+ else:
57
+ i = 255. * images.cpu().numpy()
58
+ raw_images = np.clip(i, 0, 255).astype(np.uint8)
59
+ args = [FFMPEG_PATH, "-v", "error", "-f", "rawvideo", "-pix_fmt", "rgb24",
60
+ "-s", f"{len(raw_images[0][0])}x{len(raw_images[0])}", "-r", str(fps), "-i", "-", "-n" ]
61
+ if format == "h264-mp4":
62
+ args += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-crf", "19"]
63
+ ext = "mp4"
64
+ type_num = 5
65
+ elif format == "h265-mp4":
66
+ args += ["-c:v", "libx265", "-pix_fmt", "yuv420p"]
67
+ ext = "mp4"
68
+ type_num = 5
69
+ elif format == "webm":
70
+ args += ["-pix_fmt", "yuv420p", "-crf", "23"]
71
+ ext = "webm"
72
+ type_num = 6
73
+ elif format == "prores":
74
+ args += ["-c:v", "prores_ks", "-profile:v", "3", "-pix_fmt", "yuv422p10le"]
75
+ ext = "mov"
76
+ type_num = 7
77
+ elif format == "gif-hd":
78
+ args += ["-filter_complex", "split=2 [a][b]; [a] palettegen [pal]; [b] [pal] paletteuse"]
79
+ ext = "gif"
80
+ type_num = 4
81
+ path = folder_paths.get_save_image_path("swarm_tmp_", folder_paths.get_temp_directory())[0]
82
+ rand = '%016x' % random.getrandbits(64)
83
+ file = os.path.join(path, f"swarm_tmp_{rand}.{ext}")
84
+ result = subprocess.run(args + [file], input=raw_images.tobytes(), capture_output=True)
85
+ if result.returncode != 0:
86
+ print(f"ffmpeg failed with return code {result.returncode}", file=sys.stderr)
87
+ f_out = result.stdout.decode("utf-8").strip()
88
+ f_err = result.stderr.decode("utf-8").strip()
89
+ if f_out:
90
+ print("ffmpeg out: " + f_out, file=sys.stderr)
91
+ if f_err:
92
+ print("ffmpeg error: " + f_err, file=sys.stderr)
93
+ raise Exception(f"ffmpeg failed: {f_err}")
94
+ # TODO: Is there a way to get ffmpeg to operate entirely in memory?
95
+ with open(file, "rb") as f:
96
+ out_img.write(f.read())
97
+ os.remove(file)
98
+
99
+ out = io.BytesIO()
100
+ header = struct.pack(">I", type_num)
101
+ out.write(header)
102
+ out.write(out_img.getvalue())
103
+ out.seek(0)
104
+ preview_bytes = out.getvalue()
105
+ server = PromptServer.instance
106
+ server.send_sync("progress", {"value": 12346, "max": 12346}, sid=server.client_id)
107
+ server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=server.client_id)
108
+
109
+ return { }
110
+
111
+ @classmethod
112
+ def IS_CHANGED(s, images, fps, lossless, quality, method, format):
113
+ return time.time()
114
+
115
+
116
+ NODE_CLASS_MAPPINGS = {
117
+ "SwarmSaveAnimationWS": SwarmSaveAnimationWS,
118
+ }
SwarmYolo.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, folder_paths, comfy
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+ class SwarmYoloDetection:
6
+ @classmethod
7
+ def INPUT_TYPES(cls):
8
+ return {
9
+ "required": {
10
+ "image": ("IMAGE",),
11
+ "model_name": (folder_paths.get_filename_list("yolov8"), ),
12
+ "index": ("INT", { "default": 0, "min": 0, "max": 256, "step": 1 }),
13
+ },
14
+ "optional": {
15
+ "class_filter": ("STRING", { "default": "", "multiline": False }),
16
+ "sort_order": (["left-right", "right-left", "top-bottom", "bottom-top", "largest-smallest", "smallest-largest"], ),
17
+ "threshold": ("FLOAT", { "default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01 }),
18
+ }
19
+ }
20
+
21
+ CATEGORY = "SwarmUI/masks"
22
+ RETURN_TYPES = ("MASK",)
23
+ FUNCTION = "seg"
24
+
25
+ def seg(self, image, model_name, index, class_filter=None, sort_order="left-right", threshold=0.25):
26
+ # TODO: Batch support?
27
+ i = 255.0 * image[0].cpu().numpy()
28
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
29
+ # TODO: Cache the model in RAM in some way?
30
+ model_path = folder_paths.get_full_path("yolov8", model_name)
31
+ if model_path is None:
32
+ raise ValueError(f"Model {model_name} not found, or yolov8 folder path not defined")
33
+ from ultralytics import YOLO
34
+ model = YOLO(model_path)
35
+ results = model.predict(img, conf=threshold)
36
+ boxes = results[0].boxes
37
+ class_ids = boxes.cls.cpu().numpy() if boxes is not None else []
38
+ selected_classes = None
39
+
40
+ if class_filter and class_filter.strip():
41
+ class_filter_list = [cls_name.strip() for cls_name in class_filter.split(",") if cls_name.strip()]
42
+ label_to_id = {name.lower(): id for id, name in model.names.items()}
43
+ selected_classes = []
44
+ for cls_name in class_filter_list:
45
+ if cls_name.isdigit():
46
+ selected_classes.append(int(cls_name))
47
+ else:
48
+ class_id = label_to_id.get(cls_name.lower())
49
+ if class_id is not None:
50
+ selected_classes.append(class_id)
51
+ else:
52
+ print(f"Class '{cls_name}' not found in the model")
53
+ selected_classes = selected_classes if selected_classes else None
54
+
55
+ masks = results[0].masks
56
+ if masks is not None and selected_classes is not None:
57
+ selected_masks = []
58
+ for i, class_id in enumerate(class_ids):
59
+ if class_id in selected_classes:
60
+ selected_masks.append(masks.data[i].cpu())
61
+ if selected_masks:
62
+ masks = torch.stack(selected_masks)
63
+ else:
64
+ masks = None
65
+
66
+ if masks is None or masks.shape[0] == 0:
67
+ if boxes is None or len(boxes) == 0:
68
+ return (torch.zeros(1, image.shape[1], image.shape[2]), )
69
+ else:
70
+ if selected_classes:
71
+ boxes = [box for i, box in enumerate(boxes) if class_ids[i] in selected_classes]
72
+ masks = torch.zeros((len(boxes), image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")
73
+ for i, box in enumerate(boxes):
74
+ x1, y1, x2, y2 = box.xyxy[0].tolist()
75
+ masks[i, int(y1):int(y2), int(x1):int(x2)] = 1.0
76
+ else:
77
+ masks = masks.data.cpu()
78
+ if masks is None or masks.shape[0] == 0:
79
+ return (torch.zeros(1, image.shape[1], image.shape[2]), )
80
+
81
+ masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(image.shape[1], image.shape[2]), mode="bilinear").squeeze(1)
82
+ if index == 0:
83
+ result = masks[0]
84
+ for i in range(1, len(masks)):
85
+ result = torch.max(result, masks[i])
86
+ return (result.unsqueeze(0), )
87
+ elif index > len(masks):
88
+ return (torch.zeros_like(masks[0]).unsqueeze(0), )
89
+ else:
90
+ sortedindices = []
91
+ for mask in masks:
92
+ match sort_order:
93
+ case "left-right":
94
+ sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
95
+ val = torch.argmax(sum_x).item()
96
+ case "right-left":
97
+ sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
98
+ val = mask.shape[1] - torch.argmax(torch.flip(sum_x, [0])).item() - 1
99
+ case "top-bottom":
100
+ sum_y = (torch.sum(mask, dim=1) != 0).to(dtype=torch.int)
101
+ val = torch.argmax(sum_y).item()
102
+ case "bottom-top":
103
+ sum_y = (torch.sum(mask, dim=1) != 0).to(dtype=torch.int)
104
+ val = mask.shape[0] - torch.argmax(torch.flip(sum_y, [0])).item() - 1
105
+ case "largest-smallest" | "smallest-largest":
106
+ val = torch.sum(mask).item()
107
+ sortedindices.append(val)
108
+ sortedindices = np.argsort(sortedindices)
109
+ if sort_order in ["right-left", "bottom-top", "largest-smallest"]:
110
+ sortedindices = sortedindices[::-1].copy()
111
+ masks = masks[sortedindices]
112
+ return (masks[index - 1].unsqueeze(0), )
113
+
114
+ NODE_CLASS_MAPPINGS = {
115
+ "SwarmYoloDetection": SwarmYoloDetection,
116
+ }
__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+
3
+ NODE_CLASS_MAPPINGS = {}
4
+
5
+ # RemBg doesn't work on all python versions and OS's
6
+ try:
7
+ from . import SwarmRemBg
8
+ NODE_CLASS_MAPPINGS.update(SwarmRemBg.NODE_CLASS_MAPPINGS)
9
+ except ImportError:
10
+ print("Error: [Swarm] RemBg not available")
11
+ traceback.print_exc()
12
+ # This uses FFMPEG which doesn't install itself properly on Macs I guess?
13
+ try:
14
+ from . import SwarmSaveAnimationWS
15
+ NODE_CLASS_MAPPINGS.update(SwarmSaveAnimationWS.NODE_CLASS_MAPPINGS)
16
+ except ImportError:
17
+ print("Error: [Swarm] SaveAnimationWS not available")
18
+ traceback.print_exc()
19
+ # Yolo uses Ultralytics, which is cursed
20
+ try:
21
+ from . import SwarmYolo
22
+ NODE_CLASS_MAPPINGS.update(SwarmYolo.NODE_CLASS_MAPPINGS)
23
+ except ImportError:
24
+ print("Error: [Swarm] Yolo not available")
25
+ traceback.print_exc()
__pycache__/SwarmRemBg.cpython-310.pyc ADDED
Binary file (1.39 kB). View file
 
__pycache__/SwarmRemBg.cpython-313.pyc ADDED
Binary file (2.48 kB). View file
 
__pycache__/SwarmSaveAnimationWS.cpython-310.pyc ADDED
Binary file (4.13 kB). View file
 
__pycache__/SwarmSaveAnimationWS.cpython-313.pyc ADDED
Binary file (7.39 kB). View file
 
__pycache__/SwarmYolo.cpython-310.pyc ADDED
Binary file (4.06 kB). View file
 
__pycache__/SwarmYolo.cpython-313.pyc ADDED
Binary file (8.02 kB). View file
 
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (722 Bytes). View file
 
__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.13 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ rembg
2
+ dill
3
+ ultralytics