prac2 / app.py
silverjini0's picture
cityscapes segformer-b0 demo
c0c9b2d
import os
import numpy as np
from PIL import Image
import torch
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
# Checkpoint identifier handed to `from_pretrained` for both the image
# processor and the segmentation model below.
MODEL_ID = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
def make_palette(num_classes: int):
    """Return ``num_classes`` RGB tuples, one per class index.

    Colors are taken in order from a fixed table of 20 entries; once the
    table is exhausted it wraps around, so class 20 reuses the color of
    class 0, class 21 the color of class 1, and so on.
    """
    swatches = (
        (255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 0, 255),
        (255, 0, 255), (0, 255, 255), (255, 165, 0), (128, 0, 128),
        (255, 192, 203), (191, 255, 0), (0, 128, 128), (165, 42, 42),
        (0, 0, 128), (128, 128, 0), (128, 0, 0), (255, 215, 0),
        (192, 192, 192), (255, 127, 80), (75, 0, 130), (238, 130, 238),
    )
    table_size = len(swatches)

    palette = []
    for class_id in range(num_classes):
        # Modulo makes the lookup wrap instead of running off the table.
        palette.append(swatches[class_id % table_size])
    return palette
def colorize(mask: np.ndarray, palette):
    """Render a 2-D class-index mask as an RGB image.

    Args:
        mask: 2-D array whose values are class indices.
        palette: Sequence of ``(R, G, B)`` triples; entry ``i`` is the color
            painted wherever ``mask == i``.

    Returns:
        A PIL image with the same height and width as ``mask``. Pixels whose
        value is not a valid index into ``palette`` are left black.

    Raises:
        ValueError: If ``mask`` is not 2-D (the shape cannot be unpacked
            into height and width).
    """
    h, w = mask.shape
    out = np.zeros((h, w, 3), dtype=np.uint8)

    lut = np.asarray(palette, dtype=np.uint8)
    if lut.size:
        # One gather through a lookup table replaces a full-image comparison
        # pass per class.
        idx = mask.astype(np.intp)
        # Only paint pixels that hold an exact, in-range class index; anything
        # else (negative, too large, fractional) stays black, matching what an
        # equality test against each index would produce.
        valid = (idx == mask) & (idx >= 0) & (idx < len(lut))
        out[valid] = lut[idx[valid]]

    return Image.fromarray(out)
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the preprocessing pipeline and the segmentation network for MODEL_ID.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
model = model.to(device)
model = model.eval()  # inference only

# One palette color per label declared in the model config.
id2label = model.config.id2label
NUM_CLASSES = len(id2label)
PALETTE = make_palette(NUM_CLASSES)
def segment(img: Image.Image, alpha: float = 0.5):
    """Segment an image and build a colored mask plus a blended overlay.

    Args:
        img: Input PIL image, or ``None`` when nothing has been provided.
        alpha: Weight of the colored mask in the overlay. ``0`` yields the
            input image unchanged, ``1`` yields the mask alone. Values
            outside ``[0, 1]`` are clamped.

    Returns:
        ``(mask_image, overlay_image)`` as PIL images at the input's size,
        or ``(None, None)`` when ``img`` is ``None``.
    """
    if img is None:
        return None, None

    # Convert once so the model and the overlay both work on the same
    # 3-channel pixels; RGBA, grayscale or palette inputs would otherwise
    # reach the processor unconverted.
    rgb = img.convert("RGB")

    # Out-of-range weights would push the blend outside 0..255 and wrap
    # around in the uint8 cast below.
    alpha = min(max(float(alpha), 0.0), 1.0)

    with torch.no_grad():
        inputs = processor(images=rgb, return_tensors="pt").to(device)
        logits = model(**inputs).logits
        # PIL reports (width, height); interpolate expects (height, width).
        up = torch.nn.functional.interpolate(
            logits, size=rgb.size[::-1], mode="bilinear", align_corners=False
        )
        pred = up.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    mask_img = colorize(pred, PALETTE)

    blended = np.asarray(rgb) * (1.0 - alpha) + np.asarray(mask_img) * alpha
    # Round before casting: plain truncation can darken a channel by one
    # level because of floating-point error in the weighted sum.
    overlay = np.rint(blended).astype(np.uint8)
    return mask_img, Image.fromarray(overlay)
def list_examples(exdir: str = "examples", extensions=(".jpg", ".jpeg", ".png")):
    """Collect example image paths in the row format ``gr.Examples`` takes.

    Args:
        exdir: Directory to scan (not recursive).
        extensions: File-name suffixes to accept, compared case-insensitively.

    Returns:
        A list of single-element lists ``[[path], ...]`` sorted by file name.
        Empty when ``exdir`` does not exist or is not a directory.
    """
    try:
        names = os.listdir(exdir)
    except (FileNotFoundError, NotADirectoryError):
        return []

    suffixes = tuple(ext.lower() for ext in extensions)
    examples = []
    for name in sorted(names):
        path = os.path.join(exdir, name)
        # Require a regular file: a sub-directory that happens to be named
        # like an image cannot be loaded as an example.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            examples.append([path])
    return examples
# Page title, also used as the Markdown heading.
title = "Cityscapes Segmentation (SegFormer-b0)"
# User-facing description (Korean). The text had been corrupted by a wrong
# character-set decode and is restored here. In English: "Demo of a
# SegFormer-b0 model trained on Cityscapes (19 classes). It segments vehicles,
# pedestrians, roads, buildings, sky, etc. in city/road scenes."
desc = (
    "Cityscapes(19 classes)로 학습된 SegFormer-b0 모델 데모입니다. "
    "도시/도로 장면에서 차량, 보행자, 도로, 건물, 하늘 등을 분할합니다."
)
# Two-column layout: inputs on the left, results on the right.
with gr.Blocks(title=title) as demo:
    # The heading emoji had been corrupted by a wrong character-set decode;
    # restored to the traffic-light emoji.
    gr.Markdown(f"# 🚦 {title}\n{desc}")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type="pil", label="Input Image")
            # The slider value is the mask's blend weight in `segment`
            # (1.0 shows the mask alone), so it is an opacity, not a
            # transparency.
            alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay Opacity")
            btn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            out_mask = gr.Image(type="pil", label="Segmentation Mask")
            out_overlay = gr.Image(type="pil", label="Overlay (Image + Mask)")

    # Only show the examples gallery when bundled example images exist.
    ex = list_examples()
    if ex:
        gr.Examples(examples=ex, inputs=[inp], examples_per_page=6, label="Examples")

    btn.click(segment, inputs=[inp, alpha], outputs=[out_mask, out_overlay])

demo.launch()