import os
import numpy as np
from PIL import Image
import torch
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

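# Semantic segmentation model fine-tuned on Cityscapes (19 urban-scene classes).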
MODEL_ID = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"

def make_palette(num_classes: int):
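    # Repeat a fixed list of 20 distinct colors to cover num_classes entries.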
    base = [
        (255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 0, 255),
        (255, 0, 255), (0, 255, 255), (255, 165, 0), (128, 0, 128),
        (255, 192, 203), (191, 255, 0), (0, 128, 128), (165, 42, 42),
        (0, 0, 128), (128, 128, 0), (128, 0, 0), (255, 215, 0),
        (192, 192, 192), (255, 127, 80), (75, 0, 130), (238, 130, 238),
    ]
    return [base[i % len(base)] for i in range(num_classes)]

def colorize(mask: np.ndarray, palette):
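    # Paint each class index in the (H, W) mask with its palette color.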
    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)
    for i in range(len(palette)):
        out[mask == i] = palette[i]
    return Image.fromarray(out)

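# Load the image processor and model once at startup; run on GPU if available.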
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID).to(device).eval()

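# Size the color palette to the model's label set.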
id2label = model.config.id2label
NUM_CLASSES = len(id2label)
PALETTE = make_palette(NUM_CLASSES)

def segment(img: Image.Image, alpha: float = 0.5):
    if img is None:
        return None, None
    img = img.convert("RGB")  # normalize mode in case of RGBA or grayscale uploads
    with torch.no_grad():
        inputs = processor(images=img, return_tensors="pt").to(device)
        outputs = model(**inputs)
        logits = outputs.logits
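        # PIL's img.size is (W, H) while interpolate expects (H, W), hence the [::-1].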
        up = torch.nn.functional.interpolate(
            logits, size=img.size[::-1], mode="bilinear", align_corners=False
        )
        pred = up.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
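    # Colorize the class map, then alpha-blend it over the original image.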
    mask_img = colorize(pred, PALETTE)
    overlay = (np.array(img) * (1 - alpha) +
               np.array(mask_img) * alpha).astype(np.uint8)
    return mask_img, Image.fromarray(overlay)

def list_examples():
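    # Gather optional sample images from a local "examples" directory.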
    exdir = "examples"
    if not os.path.isdir(exdir):
        return []
    names = [f for f in os.listdir(exdir)
             if f.lower().endswith((".jpg", ".jpeg", ".png"))]
    return [[os.path.join(exdir, n)] for n in sorted(names)]

title = "Cityscapes Segmentation (SegFormer-b0)"
desc = (
    "Demo of a SegFormer-b0 model trained on Cityscapes (19 classes). "
    "It segments vehicles, pedestrians, roads, buildings, sky, and more in urban street scenes."
)

with gr.Blocks(title=title) as demo:
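    # Two-column layout: input image and controls on the left, outputs on the right.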
    gr.Markdown(f"# 🚦 {title}\n{desc}")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input Image")
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Mask Opacity")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            out_mask = gr.Image(type="pil", label="Segmentation Mask")
            out_overlay = gr.Image(type="pil", label="Overlay (Image + Mask)")
    ex = list_examples()
    if ex:
        gr.Examples(examples=ex, inputs=[inp], examples_per_page=6, label="Examples")
    btn.click(segment, inputs=[inp, alpha], outputs=[out_mask, out_overlay])

demo.launch()