# NOTE: web-extraction residue ("Spaces: Running") removed; the module begins
# with the license header below.
| """ | |
| Copyright © 2022 Howard Hughes Medical Institute, | |
| Authored by Carsen Stringer and Marius Pachitariu. | |
| Redistribution and use in source and binary forms, with or without | |
| modification, are permitted provided that the following conditions are met: | |
| 1. Redistributions of source code must retain the above copyright notice, | |
| this list of conditions and the following disclaimer. | |
| 2. Redistributions in binary form must reproduce the above copyright notice, | |
| this list of conditions and the following disclaimer in the documentation | |
| and/or other materials provided with the distribution. | |
| 3. Neither the name of HHMI nor the names of its contributors may be used to | |
| endorse or promote products derived from this software without specific | |
| prior written permission. | |
| THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
| ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |
| LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |
| CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |
| SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
| INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
| CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |
| ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |
| POSSIBILITY OF SUCH DAMAGE. | |
| -------------------------------------------------------------------------- | |
| MEDIAR Prediction uses CellPose's Gradient Flow Tracking. | |
| This code is adapted from the following codes: | |
| [1] https://github.com/MouseLand/cellpose/blob/main/cellpose/utils.py | |
| [2] https://github.com/MouseLand/cellpose/blob/main/cellpose/dynamics.py | |
| [3] https://github.com/MouseLand/cellpose/blob/main/cellpose/metrics.py | |
| """ | |
| import torch | |
| from torch.nn.functional import grid_sample | |
| import numpy as np | |
| import fastremap | |
| from skimage import morphology | |
| from scipy.ndimage import mean, find_objects | |
| from scipy.ndimage.filters import maximum_filter1d | |
# Shared torch devices for this module; the CUDA device object is created
# eagerly but only used when callers request the GPU.
torch_GPU = torch.device("cuda")
torch_CPU = torch.device("cpu")
def labels_to_flows(labels, use_gpu=False, device=None, redo_flows=False):
    """Convert a batch of label masks into flow targets for training.

    Args:
        labels: b x 1 x h x w tensor of instance labels (or precomputed flows).
        use_gpu / device: forwarded to ``masks_to_flows``.
        redo_flows: force recomputation even if flows appear present.

    Returns:
        np.ndarray of float32 per-image targets: label map, binary mask, and
        the two flow components concatenated along the channel axis.
    """
    # Work on an int16 numpy copy of the incoming tensor.
    labels = labels.cpu().numpy().astype(np.int16)
    if labels[0].ndim < 3:
        # Promote 2D label images to a single-channel layout.
        labels = [lbl[np.newaxis, :, :] for lbl in labels]

    # Recompute flows when only a raw label channel is present (or forced).
    if redo_flows or labels[0].ndim < 3 or labels[0].shape[0] == 1:
        # Renumber so label ids are consecutive (required downstream).
        labels = [fastremap.renumber(lbl, in_place=True)[0] for lbl in labels]
        flow_fields = [
            masks_to_flows(lbl[0], use_gpu=use_gpu, device=device) for lbl in labels
        ]
        # Per image: labels, binary foreground, and the vector flow field.
        flows = [
            np.concatenate((lbl, lbl > 0.5, vec), axis=0).astype(np.float32)
            for lbl, vec in zip(labels, flow_fields)
        ]
    # NOTE(review): if the recompute condition is False, `flows` is unbound
    # and this raises NameError — preserved from the original; confirm callers
    # never hit that path.
    return np.array(flows)
def compute_masks(
    dP,
    cellprob,
    p=None,
    niter=200,
    cellprob_threshold=0.4,
    flow_threshold=0.4,
    interp=True,
    resize=None,
    use_gpu=False,
    device=None,
):
    """Compute labeled instance masks from flows and cell probability.

    Args:
        dP: (2, Ly, Lx) predicted flow field.
        cellprob: (Ly, Lx) cell-probability map.
        p: optional precomputed final pixel positions; computed when None.
        niter: number of flow-following iterations.
        cellprob_threshold: threshold for the foreground mask.
        flow_threshold: max allowed per-mask flow error (None/0 skips check).
        interp: interpolate flows while tracking.
        resize: optional shape used for the empty result.
        use_gpu / device: forwarded to the flow-tracking helpers.

    Returns:
        (mask, p): uint16 labeled mask and final pixel positions.
    """

    def _empty_result():
        # Nothing to segment: all-zero outputs of a compatible shape.
        shape = resize if resize is not None else cellprob.shape
        mask = np.zeros(shape, np.uint16)
        return mask, np.zeros((len(shape), *shape), np.uint16)

    cp_mask = cellprob > cellprob_threshold
    # Clean up the binary foreground before tracking.
    cp_mask = morphology.remove_small_holes(cp_mask, area_threshold=16)
    cp_mask = morphology.remove_small_objects(cp_mask, min_size=16)

    # cp_mask is a binary cell-cluster map at this point, not labels.
    if not np.any(cp_mask):
        return _empty_result()

    if p is None:
        # Follow flows from each foreground pixel to its sink.
        p, inds = follow_flows(
            dP * cp_mask / 5.0,
            niter=niter,
            interp=interp,
            use_gpu=use_gpu,
            device=device,
        )
        # Bug fix: this check only makes sense when follow_flows just ran;
        # in the original flattened form it would hit an unbound `inds`
        # whenever a precomputed `p` was supplied.
        if inds is None:
            return _empty_result()

    # Cluster converged pixel positions into labeled masks.
    mask = get_masks(p, iscell=cp_mask)

    # Flow thresholding factored out of get_masks: drop masks whose
    # recomputed flow disagrees too much with the network flow.
    if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
        mask = remove_bad_flow_masks(
            mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device
        )
    return mask, p
| def _extend_centers_gpu( | |
| neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device("cuda") | |
| ): | |
| if device is not None: | |
| device = device | |
| nimg = neighbors.shape[0] // 9 | |
| pt = torch.from_numpy(neighbors).to(device) | |
| T = torch.zeros((nimg, Ly, Lx), dtype=torch.double, device=device) | |
| meds = torch.from_numpy(centers.astype(int)).to(device).long() | |
| isneigh = torch.from_numpy(isneighbor).to(device) | |
| for i in range(n_iter): | |
| T[:, meds[:, 0], meds[:, 1]] += 1 | |
| Tneigh = T[:, pt[:, :, 0], pt[:, :, 1]] | |
| Tneigh *= isneigh | |
| T[:, pt[0, :, 0], pt[0, :, 1]] = Tneigh.mean(axis=1) | |
| del meds, isneigh, Tneigh | |
| T = torch.log(1.0 + T) | |
| # gradient positions | |
| grads = T[:, pt[[2, 1, 4, 3], :, 0], pt[[2, 1, 4, 3], :, 1]] | |
| del pt | |
| dy = grads[:, 0] - grads[:, 1] | |
| dx = grads[:, 2] - grads[:, 3] | |
| del grads | |
| mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2) | |
| return mu_torch | |
def diameters(masks):
    """Return the median mask diameter and per-mask size estimates.

    The per-mask estimate is sqrt(pixel count); the median is rescaled so a
    circular mask of area A maps to its true diameter 2*sqrt(A/pi). The first
    label (assumed background) is dropped from the counts.
    """
    sizes = np.unique(np.int32(masks), return_counts=True)[1][1:]
    radii = sizes ** 0.5
    median_diam = np.median(radii)
    if np.isnan(median_diam):
        # No foreground labels at all.
        median_diam = 0
    # Rescale: disc of area A has diameter 2*sqrt(A/pi).
    median_diam /= (np.pi ** 0.5) / 2
    return median_diam, radii
def masks_to_flows_gpu(masks, device=None):
    """Convert a 2D label image into a normalized flow field.

    Diffuses heat from each mask's center (via ``_extend_centers_gpu``) and
    takes gradients of the resulting log-heat map.

    Returns:
        (mu0, mu_c): mu0 is the (2, Ly0, Lx0) unit flow field; mu_c is an
        all-zero placeholder of the same shape.
    """
    if device is None:
        device = torch.device("cuda")

    Ly0, Lx0 = masks.shape
    # Pad by one pixel so every mask pixel has all 8 neighbors in-bounds.
    Ly, Lx = Ly0 + 2, Lx0 + 2

    masks_padded = np.zeros((Ly, Lx), np.int64)
    masks_padded[1:-1, 1:-1] = masks

    # get mask pixel neighbors
    # Order matters downstream: row 0 is the pixel itself, rows 1..8 its
    # 8-neighborhood; _extend_centers_gpu indexes rows [2,1,4,3] for gradients.
    y, x = np.nonzero(masks_padded)
    neighborsY = np.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), axis=0)
    neighborsX = np.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), axis=0)
    neighbors = np.stack((neighborsY, neighborsX), axis=-1)

    # get mask centers: the actual mask pixel closest to the median coordinate
    slices = find_objects(masks)
    centers = np.zeros((masks.max(), 2), "int")
    for i, si in enumerate(slices):
        if si is not None:
            sr, sc = si
            # (ly, lx computed but unused below)
            ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
            yi, xi = np.nonzero(masks[sr, sc] == (i + 1))
            yi = yi.astype(np.int32) + 1  # add padding
            xi = xi.astype(np.int32) + 1  # add padding
            ymed = np.median(yi)
            xmed = np.median(xi)
            # Snap the median to the nearest pixel actually inside the mask.
            imin = np.argmin((xi - xmed) ** 2 + (yi - ymed) ** 2)
            xmed = xi[imin]
            ymed = yi[imin]
            centers[i, 0] = ymed + sr.start
            centers[i, 1] = xmed + sc.start

    # get neighbor validator (not all neighbors are in same mask)
    neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1]]
    isneighbor = neighbor_masks == neighbor_masks[0]

    # Diffusion iteration count scales with the largest mask extent.
    # NOTE(review): this comprehension assumes no entry of `slices` is None
    # (i.e. labels are consecutive with no skipped ids); a None entry would
    # raise here even though the centers loop above guards for it — confirm.
    ext = np.array(
        [[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in slices]
    )
    n_iter = 2 * (ext.sum(axis=1)).max()

    # run diffusion
    mu = _extend_centers_gpu(
        neighbors, centers, isneighbor, Ly, Lx, n_iter=n_iter, device=device
    )
    # normalize to unit vectors
    mu /= 1e-20 + (mu ** 2).sum(axis=0) ** 0.5

    # put into original image (undo the one-pixel padding)
    mu0 = np.zeros((2, Ly0, Lx0))
    mu0[:, y - 1, x - 1] = mu
    mu_c = np.zeros_like(mu0)
    return mu0, mu_c
def masks_to_flows(masks, use_gpu=False, device=None):
    """Convert 2D or 3D label masks to flow fields.

    Args:
        masks: int label image, 2D (Ly, Lx) or 3D (Lz, Ly, Lx).
        use_gpu: run the diffusion on CUDA when no explicit device is given.
        device: torch device override.

    Returns:
        float32 flows: (2, Ly, Lx) for 2D input, (3, Lz, Ly, Lx) for 3D.

    Raises:
        ValueError: if ``masks`` is neither 2D nor 3D.
    """
    if masks.max() == 0 or (masks != 0).sum() == 1:
        # Empty (or single-pixel) masks: nothing to diffuse.
        # NOTE(review): for 3D input this returns a (2, Lz, Ly, Lx) array
        # while the non-empty 3D path returns (3, ...) — preserved; confirm.
        return np.zeros((2, *masks.shape), "float32")

    # Bug fix: the original only selected a device / implementation inside an
    # `if use_gpu:` branch, leaving `masks_to_flows_device` unbound (NameError)
    # whenever use_gpu was False. The torch implementation runs on CPU too.
    if device is None:
        device = torch_GPU if use_gpu else torch_CPU
    masks_to_flows_device = masks_to_flows_gpu

    if masks.ndim == 3:
        # 3D flows: accumulate 2D flows over the three orthogonal slicings.
        Lz, Ly, Lx = masks.shape
        mu = np.zeros((3, Lz, Ly, Lx), np.float32)
        for z in range(Lz):
            mu0 = masks_to_flows_device(masks[z], device=device)[0]
            mu[[1, 2], z] += mu0
        for y in range(Ly):
            mu0 = masks_to_flows_device(masks[:, y], device=device)[0]
            mu[[0, 2], :, y] += mu0
        for x in range(Lx):
            mu0 = masks_to_flows_device(masks[:, :, x], device=device)[0]
            mu[[0, 1], :, :, x] += mu0
        return mu
    elif masks.ndim == 2:
        mu, mu_c = masks_to_flows_device(masks, device=device)
        return mu
    else:
        raise ValueError("masks_to_flows only takes 2D or 3D arrays")
def steps2D_interp(p, dP, niter, use_gpu=False, device=None):
    """Advance points ``p`` along the flow field ``dP`` by Euler integration.

    Args:
        p: (2, n_points) float array of (y, x) positions to evolve.
        dP: (2, Ly, Lx) flow field.
        niter: number of integration steps.
        use_gpu: select CUDA when no explicit device is given.
        device: torch device override.

    Returns:
        (2, n_points) array of final positions, clamped to the image bounds.
    """
    shape = dP.shape[1:]
    # Bug fix: the non-GPU branch used to be a placeholder (`assert
    # print("ho")`) that always raised; the torch code below runs on any device.
    if device is None:
        device = torch.device("cuda") if use_gpu else torch.device("cpu")

    shape = (
        np.array(shape)[[1, 0]].astype("float") - 1
    )  # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
    pt = (
        torch.from_numpy(p[[1, 0]].T).float().to(device).unsqueeze(0).unsqueeze(0)
    )  # p is n_points by 2, so pt is [1 1 n_points 2]
    im = (
        torch.from_numpy(dP[[1, 0]]).float().to(device).unsqueeze(0)
    )  # convert flow numpy array to tensor, add batch dimension

    # Normalize positions and flow into grid_sample's [-1, 1] coordinates.
    for k in range(2):
        im[:, k, :, :] *= 2.0 / shape[k]
        pt[:, :, :, k] /= shape[k]
    pt = pt * 2 - 1

    # Here is where the stepping happens.
    for t in range(niter):
        # align_corners default is False, just added to suppress warning
        dPt = grid_sample(im, pt, align_corners=False)
        for k in range(2):  # clamp so points never leave the image
            pt[:, :, :, k] = torch.clamp(
                pt[:, :, :, k] + dPt[:, k, :, :], -1.0, 1.0
            )

    # Undo the normalization from before, reverse order of operations.
    pt = (pt + 1) * 0.5
    for k in range(2):
        pt[:, :, :, k] *= shape[k]

    return pt[:, :, :, [1, 0]].cpu().numpy().squeeze().T
def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None):
    """Track every pixel along the flow field ``dP`` for ``niter`` steps.

    Args:
        dP: (2, Ly, Lx) flow field.
        mask: unused; kept for interface compatibility.
        niter: number of integration steps.
        interp: must be True; only the interpolated tracker is implemented.
        use_gpu / device: forwarded to ``steps2D_interp``.

    Returns:
        (p, inds): final (2, Ly, Lx) pixel positions and the (n, 2) indices of
        the pixels that were tracked, or (p, None) when too few pixels move.

    Raises:
        NotImplementedError: if ``interp`` is False.
    """
    shape = np.array(dP.shape[1:]).astype(np.int32)
    niter = np.uint32(niter)

    # Start every pixel at its own (y, x) coordinate.
    p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    p = np.array(p).astype(np.float32)

    # Only track pixels with non-negligible flow in the y component.
    inds = np.array(np.nonzero(np.abs(dP[0]) > 1e-3)).astype(np.int32).T
    if inds.ndim < 2 or inds.shape[0] < 5:
        return p, None

    if not interp:
        # Bug fix: this branch was a placeholder (`assert print("woo")`) that
        # always raised an opaque AssertionError; fail with a clear error.
        raise NotImplementedError("follow_flows only supports interp=True")

    p[:, inds[:, 0], inds[:, 1]] = steps2D_interp(
        p[:, inds[:, 0], inds[:, 1]], dP, niter, use_gpu=use_gpu, device=device
    )
    return p, inds
def flow_error(maski, dP_net, use_gpu=False, device=None):
    """Per-mask mean squared error between network flows and mask-derived flows.

    Args:
        maski: labeled mask image.
        dP_net: network-predicted flow field, spatial shape must match maski.
        use_gpu / device: forwarded to ``masks_to_flows``.

    Returns:
        (flow_errors, dP_masks): MSE per label (labels 1..max) and the flows
        recomputed from ``maski``; None when the shapes do not match.
    """
    if dP_net.shape[1:] != maski.shape:
        print("ERROR: net flow is not same size as predicted masks")
        return

    # Flows that the predicted masks *should* produce.
    dP_masks = masks_to_flows(maski, use_gpu=use_gpu, device=device)

    n_labels = maski.max()
    label_ids = np.arange(1, n_labels + 1)
    flow_errors = np.zeros(n_labels)
    # Accumulate per-label mean squared difference over each flow component.
    for i in range(dP_masks.shape[0]):
        flow_errors += mean(
            (dP_masks[i] - dP_net[i] / 5.0) ** 2, maski, index=label_ids
        )
    return flow_errors, dP_masks
def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=None):
    """Zero out (in place) any mask whose flow error exceeds ``threshold``.

    Returns the same ``masks`` array with bad labels set to background.
    """
    flow_errors, _ = flow_error(masks, flows, use_gpu, device)
    # flow_errors[i] corresponds to label i+1, hence the offset.
    bad_labels = (flow_errors > threshold).nonzero()[0] + 1
    masks[np.isin(masks, bad_labels)] = 0
    return masks
def get_masks(p, iscell=None, rpad=20):
    """Cluster final pixel positions ``p`` into labeled masks.

    Histograms where pixels converged, finds local-maximum seeds, grows each
    seed over 5 expansion iterations, then maps labels back to pixels.

    NOTE(review): ``iscell`` is accepted but never used in this body —
    confirm against the caller (compute_masks passes it).
    """
    pflows = []
    edges = []
    shape0 = p.shape[1:]
    dims = len(p)

    # Histogram the (rounded) final positions, padded by rpad bins per side
    # so expansion below cannot run off the histogram.
    for i in range(dims):
        pflows.append(p[i].flatten().astype("int32"))
        edges.append(np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1))
    h, _ = np.histogramdd(tuple(pflows), bins=edges)

    # Local maxima of the histogram (within a 5-wide window per axis).
    hmax = h.copy()
    for i in range(dims):
        hmax = maximum_filter1d(hmax, 5, axis=i)

    # Seeds: local maxima where more than 10 pixels converged.
    seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10))
    Nmax = h[seeds]
    isort = np.argsort(Nmax)[::-1]
    # NOTE(review): rebinding the loop variable does NOT reorder `seeds`;
    # the intended sort-by-size appears to be a no-op — confirm.
    for s in seeds:
        s = s[isort]
    pix = list(np.array(seeds).T)

    shape = h.shape
    if dims == 3:
        expand = np.nonzero(np.ones((3, 3, 3)))
    else:
        expand = np.nonzero(np.ones((3, 3)))
    # NOTE(review): rebinding `e` here is also a no-op; `expand` is unchanged.
    for e in expand:
        e = np.expand_dims(e, 1)

    # Grow each seed's pixel set 5 times into neighboring high-count bins.
    for iter in range(5):
        for k in range(len(pix)):
            if iter == 0:
                pix[k] = list(pix[k])
            newpix = []
            iin = []
            for i, e in enumerate(expand):
                # Candidate coordinates: current pixels offset by -1..+1.
                epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
                epix = epix.flatten()
                iin.append(np.logical_and(epix >= 0, epix < shape[i]))
                newpix.append(epix)
            iin = np.all(tuple(iin), axis=0)
            # NOTE(review): rebinding `p` (which also shadows the parameter)
            # does not filter `newpix`; the bounds filter computed in `iin`
            # is never applied — the rpad padding appears to keep indices
            # in-bounds in practice, but confirm.
            for p in newpix:
                p = p[iin]
            newpix = tuple(newpix)
            # Keep only candidate bins where enough pixels converged.
            igood = h[newpix] > 2
            for i in range(dims):
                pix[k][i] = newpix[i][igood]
            if iter == 4:
                pix[k] = tuple(pix[k])

    # Paint each grown seed region with its label in histogram space.
    M = np.zeros(h.shape, np.uint32)
    for k in range(len(pix)):
        M[pix[k]] = 1 + k

    # Map every pixel's final position (shifted by rpad) to its label.
    for i in range(dims):
        pflows[i] = pflows[i] + rpad
    M0 = M[tuple(pflows)]

    # remove big masks (covering more than 90% of the image)
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * 0.9
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True)  # convenient to guarantee non-skipped labels
    M0 = np.reshape(M0, shape0)
    return M0