import os
import copy
import urllib.request

import numpy as np
import tqdm
import mediapipe as mp
import torch
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from sklearn.neighbors import NearestNeighbors

from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
from utils.commons.tensor_utils import convert_to_np

# Number of categories produced by the selfie-multiclass model:
# 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes, 5 others.
_NUM_SEG_CLASSES = 6
# A category id k is stored in "segmap images" as the gray level k * 40.
_SEGMAP_GRAY_STEP = 40
# Timestamp step (ms) between consecutive frames fed to the VIDEO-mode segmenter.
_VIDEO_FRAME_INTERVAL_MS = 40
# RGB color used for each category by encode/decode_segmap_mask_*_image.
_SEGMAP_COLORS = [
    (255, 255, 255),  # background
    (255, 255, 0),    # hair
    (255, 0, 255),    # body-skin
    (0, 255, 255),    # face-skin
    (255, 0, 0),      # clothes
    (0, 255, 0),      # others
]


def scatter_np(condition_img, classSeg=5):
    """Convert an integer label map into a one-hot mask (NumPy version).

    Args:
        condition_img: array of shape [B, 1, H, W] holding class indices in
            ``[0, classSeg)``. Non-integer dtypes are truncated to int64,
            mirroring the ``.long()`` cast done by the torch ``scatter``.
        classSeg: number of classes, i.e. size of the output channel axis.

    Returns:
        Integer array of shape [B, classSeg, H, W] with a 1 at each pixel's
        class channel and 0 elsewhere.
    """
    condition_img = np.asarray(condition_img)
    if not np.issubdtype(condition_img.dtype, np.integer):
        # np.put_along_axis raises IndexError on non-integer indices.
        condition_img = condition_img.astype(np.int64)
    batch, _, height, width = condition_img.shape
    input_label = np.zeros([batch, classSeg, height, width], dtype=np.int_)
    np.put_along_axis(input_label, condition_img, 1, 1)
    return input_label


def scatter(condition_img, classSeg=19):
    """Convert an integer label map into a one-hot mask (torch version).

    Args:
        condition_img: tensor of shape [B, 1, H, W] holding class indices.
        classSeg: number of classes, i.e. size of the output channel axis.

    Returns:
        Float tensor of shape [B, classSeg, H, W] on ``condition_img.device``.
    """
    batch, _, height, width = condition_img.size()
    input_label = torch.zeros(batch, classSeg, height, width, device=condition_img.device)
    return input_label.scatter_(1, condition_img.long(), 1)


def encode_segmap_mask_to_image(segmap):
    """Encode a one-hot segmap of shape [6, H, W] as an RGB image [H, W, 3].

    Each channel ``i`` paints its non-zero pixels with ``_SEGMAP_COLORS[i]``;
    later channels overwrite earlier ones. Pixels set in no channel stay
    white, which is also the background color.
    """
    _, h, w = segmap.shape
    encoded_img = np.full([h, w, 3], 255, dtype=np.uint8)
    for i, color in enumerate(_SEGMAP_COLORS):
        # Cast to int first so fractional values in (-1, 1) count as "unset".
        active = segmap[i].astype(int) != 0
        encoded_img[active] = np.array(color)
    return encoded_img


def decode_segmap_mask_from_image(encoded_img):
    """Decode an RGB image written by ``encode_segmap_mask_to_image``.

    Args:
        encoded_img: array [H, W, C>=3]; only the first three channels are read.

    Returns:
        uint8 array [6, H, W]; channel ``i`` is 1 where the pixel equals
        ``_SEGMAP_COLORS[i]`` exactly, else 0.
    """
    rgb = encoded_img[..., :3]
    channels = [np.all(rgb == np.array(color), axis=-1) for color in _SEGMAP_COLORS]
    return np.stack(channels, axis=0).astype(np.uint8)


def read_video_frame(video_name, frame_id):
    """Read a single frame from a video with OpenCV.

    Args:
        video_name: path of the video file.
        frame_id: zero-based index of the frame to seek to.

    Returns:
        The frame as returned by ``cv2.VideoCapture.read`` (BGR array), or
        ``None`` when the read fails.
    """
    # Imported here because the module does not import cv2 at the top level.
    import cv2

    vr = cv2.VideoCapture(video_name)
    try:
        vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
        _, frame = vr.read()
    finally:
        # Always release the capture handle, even if seeking/reading raises.
        vr.release()
    return frame


def decode_segmap_mask_from_segmap_video_frame(video_frame):
    """Decode a gray-level segmap video frame into a one-hot mask.

    Args:
        video_frame: array [H, W, 3] in 0~255 (e.g. from ``read_video_frame``)
            where category ``k`` was stored as gray level ``k * 40``.

    Returns:
        uint8 array [6, H, W], the one-hot segmentation mask.
    """
    def assign_values(array):
        # Snap every value to the nearest multiple of the gray step
        # (remainder exactly at half the step rounds down).
        remainder = array % _SEGMAP_GRAY_STEP
        return np.where(
            remainder <= _SEGMAP_GRAY_STEP // 2,
            array - remainder,
            array + (_SEGMAP_GRAY_STEP - remainder),
        )

    segmap = video_frame.mean(-1)
    segmap = assign_values(segmap) // _SEGMAP_GRAY_STEP  # [H, W], nominally 0~5
    # Gray levels near 255 would snap to class 6; clamp so that the one-hot
    # scatter never indexes out of range, then cast to int for indexing.
    segmap = np.clip(segmap, 0, _NUM_SEG_CLASSES - 1).astype(np.int64)
    segmap_mask = scatter_np(segmap[None, None, ...], classSeg=_NUM_SEG_CLASSES)[0]  # [6, H, W]
    return segmap_mask.astype(np.uint8)


def extract_background(img_lst, segmap_lst=None, sample_stride=20, dist_threshold=10):
    """Estimate a static background image from a sequence of frames.

    For every pixel, the frame in which that pixel is farthest from any
    foreground pixel is selected; pixels that are never farther than
    ``dist_threshold`` from the foreground are filled from the nearest
    pixel that was recovered.

    Args:
        img_lst: sequence of RGB uint8 arrays [H, W, 3].
        segmap_lst: optional sequence of one-hot segmaps [C, H, W] aligned
            with ``img_lst`` (channel 0 = background). When ``None`` the
            segmaps are computed with ``MediapipeSegmenter``.
        sample_stride: only every ``sample_stride``-th frame is used; if there
            are not more than ``sample_stride`` frames only the first is used.
        dist_threshold: minimum pixel distance to the foreground for a pixel
            to be trusted as background.

    Returns:
        uint8 array [H, W, 3]. All zeros if no pixel could be recovered.

    Raises:
        ValueError: if ``img_lst`` is empty.
    """
    num_frames = len(img_lst)
    if num_frames == 0:
        raise ValueError("img_lst must contain at least one frame")

    use_stride = num_frames > sample_stride
    img_lst = img_lst[::sample_stride] if use_stride else img_lst[0:1]
    if segmap_lst is not None:
        segmap_lst = segmap_lst[::sample_stride] if use_stride else segmap_lst[0:1]
        assert len(img_lst) == len(segmap_lst)
    else:
        # Build the segmaps here instead of relying on a global `seg_model`,
        # which only exists when this file is run as a script.
        seg_model = MediapipeSegmenter()
        with vision.ImageSegmenter.create_from_options(seg_model.options) as segmenter:
            segmap_lst = [seg_model._cal_seg_map(img, segmenter=segmenter) for img in img_lst]

    h, w = img_lst[0].shape[:2]
    num_pixs = h * w
    all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()  # [H*W, 2]

    # Distance from every pixel to the nearest foreground pixel, per frame.
    dists_per_frame = []
    for segmap in segmap_lst:
        bg = segmap[0].astype(bool)
        fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
        if len(fg_xys) == 0:
            # No foreground at all: the whole frame is background.
            dists = np.full(num_pixs, np.inf)
        else:
            nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
            dists, _ = nbrs.kneighbors(all_xys)
            dists = dists[:, 0]
        dists_per_frame.append(dists)
    distss = np.stack(dists_per_frame)  # [F, H*W]

    max_dist = distss.max(0)
    max_id = distss.argmax(0)
    bc_pixs = max_dist > dist_threshold  # pixels trusted as background
    bc_pixs_id = np.nonzero(bc_pixs)[0]

    imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
    bg_img = np.zeros((num_pixs, 3), dtype=np.uint8)
    bg_img[bc_pixs_id] = imgs[max_id[bc_pixs_id], bc_pixs_id]
    bg_img = bg_img.reshape(h, w, 3)

    # Fill the untrusted pixels from the nearest trusted one.
    bc_mask = bc_pixs.reshape(h, w)
    known_xys = np.stack(np.nonzero(bc_mask)).transpose()
    unknown_xys = np.stack(np.nonzero(~bc_mask)).transpose()
    if len(known_xys) > 0 and len(unknown_xys) > 0:
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(known_xys)
        _, indices = nbrs.kneighbors(unknown_xys)
        src_xys = known_xys[indices[:, 0]]
        bg_img[unknown_xys[:, 0], unknown_xys[:, 1], :] = bg_img[src_xys[:, 0], src_xys[:, 1], :]
    return bg_img


def _category_mask_to_gray_image(segmap):
    """Turn a category map [H, W] into a 3-channel uint8 image of k * 40."""
    segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
    return (segmap_image * _SEGMAP_GRAY_STEP).astype(np.uint8)


# Per-process segmenter reused by job_cal_seg_map_for_image.
global_segmenter = None


def job_cal_seg_map_for_image(img, segmenter_options=None, segmenter=None):
    """Segment one image; worker used by
    ``MediapipeSegmenter.multiprocess_cal_seg_map_for_a_video`` to process a
    single long video.

    Args:
        img: RGB uint8 array [H, W, 3].
        segmenter_options: ``vision.ImageSegmenterOptions`` used to lazily
            create the module-level segmenter when ``segmenter`` is ``None``.
        segmenter: an existing ``vision.ImageSegmenter`` to use instead.

    Returns:
        ``(segmap_mask, segmap_image)``: the one-hot mask [6, H, W] and the
        gray-level image [H, W, 3] (category id * 40).

    Raises:
        ValueError: if neither a segmenter nor options are available.
    """
    global global_segmenter
    if segmenter is not None:
        segmenter_actual = segmenter
    else:
        if global_segmenter is None:
            if segmenter_options is None:
                raise ValueError("either `segmenter` or `segmenter_options` must be provided")
            global_segmenter = vision.ImageSegmenter.create_from_options(segmenter_options)
        segmenter_actual = global_segmenter
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
    out = segmenter_actual.segment(mp_image)
    segmap = out.category_mask.numpy_view().copy()  # [H, W]
    segmap_mask = scatter_np(segmap[None, None, ...], classSeg=_NUM_SEG_CLASSES)[0]  # [6, H, W]
    segmap_image = _category_mask_to_gray_image(segmap)
    return segmap_mask, segmap_image


class MediapipeSegmenter:
    """Wrapper around MediaPipe's selfie multiclass image segmenter.

    Categories: 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes,
    5 others (accessories).
    """

    _MODEL_PATH = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
    _MODEL_URL = (
        "https://storage.googleapis.com/mediapipe-models/image_segmenter/"
        "selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite"
    )
    # Category channels kept for each segmentation mode.
    _MODE_CLASS_IDS = {
        'head': [1, 3, 5],  # glasses also belong to "others"
        'person': [1, 2, 3, 4, 5],
        'torso': [2, 4],
        'torso_with_bg': [0, 2, 4],
        'bg': [0],
    }

    def __init__(self):
        """Download the model if missing and build IMAGE / VIDEO mode options."""
        model_path = self._MODEL_PATH
        if not os.path.exists(model_path):
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            print("downloading segmenter model from mediapipe...")
            # Download to a temporary name and rename afterwards so that an
            # interrupted download never leaves a truncated model in place;
            # network failures raise instead of being silently ignored.
            tmp_path = model_path + ".part"
            urllib.request.urlretrieve(self._MODEL_URL, tmp_path)
            os.replace(tmp_path, model_path)
            print("download success")
        base_options = python.BaseOptions(model_asset_path=model_path)
        self.options = vision.ImageSegmenterOptions(
            base_options=base_options, running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
        self.video_options = vision.ImageSegmenterOptions(
            base_options=base_options, running_mode=vision.RunningMode.VIDEO, output_category_mask=True)

    def multiprocess_cal_seg_map_for_a_video(self, imgs, num_workers=4):
        """Segment the frames of a single long video in parallel.

        Args:
            imgs: list of RGB arrays in 0~255.
            num_workers: number of worker processes.

        Returns:
            ``(segmap_masks, segmap_images)``: two lists, in the order the
            results are yielded by ``multiprocess_run_tqdm``.
        """
        segmap_masks = []
        segmap_images = []
        # Argument order must match job_cal_seg_map_for_image(img, segmenter_options).
        # NOTE(review): assumes multiprocess_run_tqdm unpacks tuple args
        # positionally - confirm against utils.commons.multiprocess_utils.
        job_args = [(img, self.options) for img in imgs]
        for (_, res) in multiprocess_run_tqdm(
                job_cal_seg_map_for_image, args=job_args, num_workers=num_workers,
                desc='extracting from a video in multi-process'):
            segmap_mask, segmap_image = res
            segmap_masks.append(segmap_mask)
            segmap_images.append(segmap_image)
        return segmap_masks, segmap_images

    def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True):
        """Segment consecutive video frames with a VIDEO-mode segmenter.

        Args:
            imgs: sequence of RGB uint8 arrays [H, W, 3].
            segmenter: optional VIDEO-mode ``vision.ImageSegmenter``; created
                from ``self.video_options`` when ``None``.
            return_onehot_mask: include the list of one-hot masks [6, H, W].
            return_segmap_image: include the list of gray images [H, W, 3].

        Returns:
            ``(segmap_masks, segmap_images)`` when both flags are set,
            otherwise only the requested list.
        """
        assert return_onehot_mask or return_segmap_image  # you should at least return one
        if segmenter is None:
            segmenter = vision.ImageSegmenter.create_from_options(self.video_options)
        segmap_masks = []
        segmap_images = []
        for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=imgs[i])
            # VIDEO mode needs a monotonically increasing timestamp in ms.
            out = segmenter.segment_for_video(mp_image, _VIDEO_FRAME_INTERVAL_MS * i)
            segmap = out.category_mask.numpy_view().copy()  # [H, W]
            if return_onehot_mask:
                segmap_mask = scatter_np(segmap[None, None, ...], classSeg=_NUM_SEG_CLASSES)[0]  # [6, H, W]
                segmap_masks.append(segmap_mask)
            if return_segmap_image:
                segmap_images.append(_category_mask_to_gray_image(segmap))

        if return_onehot_mask and return_segmap_image:
            return segmap_masks, segmap_images
        if return_onehot_mask:
            return segmap_masks
        return segmap_images

    def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
        """Segment a single image.

        Args:
            img: numpy array [H, W, 3] in 0~255.
            segmenter: optional ``vision.ImageSegmenter`` (IMAGE mode);
                created from ``self.options`` when ``None``.
            return_onehot_mask: return a one-hot mask instead of a label map.

        Returns:
            One-hot mask [6, H, W] if ``return_onehot_mask`` else the
            category map [H, W], with categories
            0 - background, 1 - hair, 2 - body-skin, 3 - face-skin,
            4 - clothes, 5 - others (accessories).
        """
        assert img.ndim == 3
        if segmenter is None:
            segmenter = vision.ImageSegmenter.create_from_options(self.options)
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
        out = segmenter.segment(image)
        segmap = out.category_mask.numpy_view().copy()  # [H, W]
        if return_onehot_mask:
            segmap = scatter_np(segmap[None, None, ...], classSeg=_NUM_SEG_CLASSES)[0]  # [6, H, W]
        return segmap

    def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
        """Black out every pixel of ``img`` outside the selected categories.

        ``img`` is modified in place.

        Args:
            img: numpy array [H, W, 3] in 0~255.
            segmap: one-hot mask [6, H, W].
            mode: one of 'head', 'person', 'torso', 'torso_with_bg', 'bg',
                or 'full' (keep everything).

        Returns:
            ``(img, selected_mask)`` where ``selected_mask`` is a boolean
            array [1, H, W] marking the pixels that were kept.

        Raises:
            NotImplementedError: for an unknown ``mode``.
        """
        if mode == 'full':
            # Nothing is removed, so every pixel counts as selected.
            selected_mask = np.ones((1,) + tuple(segmap.shape[1:]), dtype=bool)
            return img, selected_mask
        if mode not in self._MODE_CLASS_IDS:
            raise NotImplementedError(f"unknown segmentation mode: {mode!r}")
        class_ids = self._MODE_CLASS_IDS[mode]
        selected_mask = segmap[class_ids, :, :].sum(axis=0)[None, :] > 0.5  # [1, H, W]
        img[~selected_mask[0]] = 0  # zero all channels of unselected pixels
        return img, selected_mask

    def _seg_out_img(self, img, segmenter=None, mode='head'):
        """Segment ``img`` and black out everything outside ``mode``.

        Args:
            img: numpy array [H, W, 3] in 0~255; modified in place.
            segmenter: optional ``vision.ImageSegmenter`` (IMAGE mode).
            mode: see ``_seg_out_img_with_segmap``.

        Returns:
            ``(img, selected_mask)`` as in ``_seg_out_img_with_segmap``.
        """
        if segmenter is None:
            segmenter = vision.ImageSegmenter.create_from_options(self.options)
        segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True)  # [6, H, W]
        return self._seg_out_img_with_segmap(img, segmap, mode=mode)

    def seg_out_imgs(self, img, mode='head'):
        """PyTorch API: segment a batch of images in the -1~1 convention.

        Args:
            img: tensor [B, 3, H, W] with values in -1~1.
            mode: see ``_seg_out_img_with_segmap``.

        Returns:
            float32 tensor [B, 3, H, W] in -1~1 on ``img.device``; pixels
            outside the selected categories are black (-1).
        """
        device = img.device
        img_np = convert_to_np(img.permute(0, 2, 3, 1))  # [B, H, W, 3]
        img_np = ((img_np + 1) * 127.5).astype(np.uint8)
        # One segmenter is shared by the whole batch and closed afterwards.
        with vision.ImageSegmenter.create_from_options(self.options) as segmenter:
            out_lst = [
                # _seg_out_img returns (image, mask); only the image is stacked.
                self._seg_out_img(np.ascontiguousarray(frame), segmenter=segmenter, mode=mode)[0]
                for frame in img_np
            ]
        seg_imgs = np.stack(out_lst).astype(np.float32)  # [B, H, W, 3]
        seg_imgs = (seg_imgs - 127.5) / 127.5
        seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
        return seg_imgs


if __name__ == '__main__':
    import imageio
    import cv2
    import tqdm
    import torchshow as ts

    img = imageio.imread("1.png")
    img = cv2.resize(img, (512, 512))

    seg_model = MediapipeSegmenter()
    img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3, 1, 2)
    img = (img - 127.5) / 127.5
    out = seg_model.seg_out_imgs(img, 'torso')
    ts.save(out, "torso.png")
    out = seg_model.seg_out_imgs(img, 'head')
    ts.save(out, "head.png")
    out = seg_model.seg_out_imgs(img, 'bg')
    ts.save(out, "bg.png")
    img = convert_to_np(img.permute(0, 2, 3, 1))  # [B, H, W, 3]
    img = ((img + 1) * 127.5).astype(np.uint8)
    bg = extract_background(img)
    ts.save(bg, "bg2.png")