|
|
import torch
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
|
|
import folder_paths
|
|
|
import os, requests
|
|
|
|
|
|
def get_path():
    """Return the directory that holds the CLIPSeg model files.

    Prefers a user-configured "clipseg" entry in ComfyUI's folder registry
    (``folder_paths``); otherwise falls back to a "models" directory located
    next to this source file.
    """
    registry = folder_paths.folder_names_and_paths
    if "clipseg" not in registry:
        # No registry entry: default to <this file's directory>/models
        return os.path.dirname(os.path.realpath(__file__)) + "/models"
    # Registry entries are (list_of_paths, extensions); take the first path.
    return registry["clipseg"][0][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_model(path, urlbase):
    """Download any missing CLIPSeg model files from `urlbase` into `path`.

    path: target directory; current callers pass it with a trailing separator,
        and file names are appended directly to it.
    urlbase: URL prefix (trailing slash included by callers) that each file
        name is appended to.

    Files that already exist on disk are left untouched, so an interrupted
    download is completed on the next run instead of being skipped.

    Raises requests.HTTPError if the server responds with an error status.
    """
    os.makedirs(path, exist_ok=True)
    for file in ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"]:
        filepath = path + file
        if os.path.exists(filepath):
            continue
        print(f"[SwarmClipSeg] Downloading '{file}'...")
        response = requests.get(f"{urlbase}{file}")
        # Fail before opening the file, so an HTTP error can't leave behind
        # an empty/partial file that later runs would mistake for a complete one.
        response.raise_for_status()
        with open(filepath, "wb") as f:
            f.write(response.content)
|
|
|
|
|
|
|
|
|
class SwarmClipSeg:
    """ComfyUI node that masks the part of an image matching a text prompt.

    Runs the CLIPSeg model (downloaded on first use) on the first image of the
    input batch and returns a thresholded, normalized mask resized to the
    input image's resolution.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "match_text": ("STRING", {"multiline": True, "tooltip": "A short description (a few words) to describe something within the image to find and mask."}),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step":0.01, "round": False, "tooltip": "Threshold to apply to the mask, higher values will make the mask more strict. Without sufficient thresholding, CLIPSeg may include random stray content around the edges."}),
            }
        }

    CATEGORY = "SwarmUI/masks"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "seg"
    DESCRIPTION = "Segment an image using CLIPSeg, creating a mask of what part of an image appears to match the given text."

    def seg(self, images, match_text, threshold):
        """Return a (1, H, W)-shaped mask of where `match_text` matches.

        images: batch tensor in [0, 1]; only images[0] is segmented, but the
            mask is resized to the batch's (height, width).
        match_text: text prompt describing the content to mask.
        threshold: sigmoid activations at or below this value are zeroed.
        """
        # Convert the first batch image from a [0, 1] float tensor to a PIL image.
        i = 255.0 * images[0].cpu().numpy()
        img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))

        # Ensure the CLIPSeg model files are present, then load them.
        path = get_path() + "/clipseg-rd64-refined-fp16-safetensors/"
        download_model(path, "https://huggingface.co/mcmonkey/clipseg-rd64-refined-fp16/resolve/main/")
        processor = CLIPSegProcessor.from_pretrained(path)
        model = CLIPSegForImageSegmentation.from_pretrained(path)

        # Inference only — no gradients needed.
        with torch.no_grad():
            mask = model(**processor(text=match_text, images=img, return_tensors="pt", padding=True))[0]

        # Zero out weak activations, then normalize the survivors to [0, 1].
        mask = torch.nn.functional.threshold(mask.sigmoid(), threshold, 0)
        mask -= mask.min()
        max_val = mask.max()  # renamed from `max` to avoid shadowing the builtin
        if max_val > 0:  # guard against dividing by zero on an all-zero mask
            mask /= max_val

        # interpolate() wants a 4D (N, C, H, W) input; pad dimensions as needed.
        while mask.ndim < 4:
            mask = mask.unsqueeze(0)
        mask = torch.nn.functional.interpolate(mask, size=(images.shape[1], images.shape[2]), mode="bilinear").squeeze(0)
        return (mask,)
|
|
|
|
|
|
# Registry consumed by ComfyUI to map node identifiers to their implementing classes.
NODE_CLASS_MAPPINGS = {"SwarmClipSeg": SwarmClipSeg}
|
|
|
|