|
|
import sys
|
|
|
import os
|
|
|
import torch
|
|
|
import torchvision.transforms as T
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
from torch.hub import download_url_to_file
|
|
|
import urllib.parse
|
|
|
|
|
|
|
|
|
|
|
|
# Third-party packages the entry points in this file rely on.
# NOTE(review): presumably read by torch.hub as the hubconf dependency list — confirm.
dependencies = [
    'tomesd', 'omegaconf', 'numpy',
    'rich', 'yapf', 'addict',
    'tqdm', 'packaging', 'torchvision',
]
|
|
|
|
|
|
|
|
|
# 'model_without_OpenMMLab' directory located next to this file; the
# segformer_plusplus package imported below is expected to live inside it.
model_dir = os.path.join(os.path.dirname(__file__), 'model_without_OpenMMLab')

# Put model_dir at the front of sys.path only while the imports below run.
sys.path.insert(0, model_dir)
try:
    from segformer_plusplus.build_model import create_model
    from segformer_plusplus.random_benchmark import random_benchmark
finally:
    # Always undo the insertion, also when an import raises. Remove the entry
    # by value instead of popping index 0: if an imported module touched
    # sys.path itself, index 0 may no longer be the entry added above.
    if model_dir in sys.path:
        sys.path.remove(model_dir)
|
|
|
|
|
|
|
|
|
def _get_local_cache_path(url: str, filename: str) -> str:
    """
    Build the local cache path for a checkpoint downloaded from ``url``.

    The path points into ``<torch home>/checkpoints`` (the directory is
    created if it does not exist yet). The file name has the form
    ``{filename}_{digest}.pt``, where ``digest`` is the first 10 hex
    characters of the SHA-256 hash of ``url``, so the same URL always maps to
    the same file while different URLs map to different files.

    Args:
        url (str): The checkpoint URL; only used to derive the digest.
        filename (str): Human-readable prefix of the cached file name.

    Returns:
        str: Full path of the cache file. The file itself is not created.
    """
    # Function-scope import so this helper does not depend on a file-level import.
    import hashlib

    # NOTE: _get_torch_home is a private torch.hub helper; it is kept so the
    # cache directory stays where it was.
    checkpoint_dir = os.path.join(torch.hub._get_torch_home(), 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Use a real digest of the whole URL. A prefix of the URL-quoted string
    # would only cover the scheme (e.g. "https%3A%2") and would therefore be
    # identical for every https URL, making distinct checkpoints collide.
    url_digest = hashlib.sha256(url.encode('utf-8')).hexdigest()[:10]

    return os.path.join(checkpoint_dir, f"{filename}_{url_digest}.pt")
|
|
|
|
|
|
|
|
|
|
|
|
def segformer_plusplus(
    backbone: str = 'b5',
    tome_strategy: str = 'bsm_hq',
    out_channels: int = 19,
    pretrained: bool = True,
    checkpoint_url: str = None,
    **kwargs
) -> torch.nn.Module:
    """
    Segformer++: Efficient Token-Merging Strategies for High-Resolution Semantic Segmentation.

    Builds a SegFormer++ model via ``create_model`` and, if ``checkpoint_url``
    is given, tries to load that checkpoint into it. Download and loading
    problems are printed rather than raised; the model is returned in every case.

    Install requirements via:
    pip install tomesd omegaconf numpy rich yapf addict tqdm packaging torchvision

    Args:
        backbone (str): The backbone type. Selectable from: ['b0', 'b1', 'b2', 'b3', 'b4', 'b5'].
        tome_strategy (str): The token merging strategy. Selectable from: ['bsm_hq', 'bsm_fast', 'n2d_2x2'].
        out_channels (int): Number of output classes (e.g., 19 for Cityscapes).
        pretrained (bool): Whether to load the ImageNet pre-trained weights
            (forwarded to ``create_model``).
        checkpoint_url (str, optional): URL of a checkpoint to load. On first
            use it is fetched with torch.hub.download_url_to_file() into the
            local checkpoint cache; later calls read the cached file. Prefer a
            direct asset link.
        **kwargs: Accepted but not used by this function.

    Returns:
        torch.nn.Module: The instantiated SegFormer++ model.
    """
    net = create_model(
        backbone=backbone,
        tome_strategy=tome_strategy,
        out_channels=out_channels,
        pretrained=pretrained,
    )

    # Without an explicit checkpoint there is nothing left to do.
    if not checkpoint_url:
        return net

    local_filepath = _get_local_cache_path(
        url=checkpoint_url,
        filename=f"segformer_plusplus_{backbone}",
    )
    print(f"Attempting to load checkpoint from {checkpoint_url}...")

    # Step 1: make sure the checkpoint file is present in the local cache.
    is_cached = os.path.exists(local_filepath)
    if not is_cached:
        try:
            print(f"File not in cache. Downloading to {local_filepath}...")
            download_url_to_file(checkpoint_url, local_filepath, progress=True)
            print("Download successful.")
        except Exception as e:
            print(f"Error downloading checkpoint from {checkpoint_url}. Check the URL or use a direct asset link. Error: {e}")
            # Best effort: hand back the model as built by create_model.
            return net

    # Step 2: read the cached file and apply it to the model.
    try:
        # NOTE(review): torch.load unpickles the downloaded file; only point
        # checkpoint_url at sources you trust.
        weights = torch.load(local_filepath, map_location='cpu')

        # Some checkpoints wrap the parameters under a 'state_dict' key.
        weights = weights['state_dict'] if 'state_dict' in weights else weights

        net.load_state_dict(weights, strict=True)
        print("Checkpoint loaded successfully.")
    except Exception as e:
        print(f"Error loading state dict from file {local_filepath}: {e}")
        print("The model was instantiated, but the checkpoint could not be loaded.")

    return net
|
|
|
|
|
|
|
|
|
|
|
|
def data_transforms(
    resolution: Tuple[int, int] = (1024, 1024),
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
) -> T.Compose:
    """
    Provides the preprocessing transforms for input images.

    The returned pipeline applies, in order: ``T.Resize(resolution)``,
    ``T.ToTensor()`` and ``T.Normalize(mean, std)``. The default ``mean`` and
    ``std`` are the typical ImageNet values.

    Args:
        resolution (Tuple[int, int]): The desired size for the images, given
            as (height, width) — the order torchvision's ``Resize`` expects
            for a sequence. Defaults to (1024, 1024).
        mean (List[float]): The mean values for normalization. Defaults to the
            ImageNet means.
        std (List[float]): The standard deviations for normalization. Defaults to the
            ImageNet standard deviations.

    Returns:
        torchvision.transforms.Compose: A composition of transforms
            that can be applied to input images.

    Example:
        >>> # Load transforms with default parameters
        >>> transform = torch.hub.load('user/repo_name', 'data_transforms')
        >>>
        >>> # Load transforms with resize to custom image resolution and default normalization
        >>> transform_small = torch.hub.load('user/repo_name', 'data_transforms', resolution=(512, 512))
    """
    return T.Compose([
        T.Resize(resolution),
        T.ToTensor(),
        # Pass copies: Normalize keeps a reference to the sequences it is
        # given, and the default lists in the signature are shared between
        # calls, so aliasing them would let one transform affect the others.
        T.Normalize(mean=list(mean), std=list(std)),
    ])
|
|
|
|
|
|
|
|
|
|
|
|
def random_benchmark_entrypoint(**kwargs):
    """
    Entry point that runs the random benchmark for SegFormer++.

    All keyword arguments are passed through unchanged to ``random_benchmark``,
    and whatever it returns is handed back to the caller.
    """
    benchmark_result = random_benchmark(**kwargs)
    return benchmark_result