openfree's picture
Deploy from GitHub repository
b20c769 verified
from pathlib import Path
from typing import Dict
from .anysat import AnySatWrapper
from .croma import CROMAWrapper
from .decur import DeCurWrapper
from .dofa import DOFAWrapper
from .mmearth import MMEarthWrapper # type: ignore
from .presto import PrestoWrapper, UnWrappedPresto
from .prithvi import PrithviWrapper # type: ignore
from .satlas import SatlasWrapper
from .satmae import SatMAEWrapper
from .softcon import SoftConWrapper
# Public names exported by this module; every entry is a class imported from
# one of the sibling modules above.
__all__ = [
"CROMAWrapper",
"DOFAWrapper",
"MMEarthWrapper",
"SatlasWrapper",
"SatMAEWrapper",
"SoftConWrapper",
"DeCurWrapper",
"PrestoWrapper",
"AnySatWrapper",
"UnWrappedPresto",
"PrithviWrapper",
]
def construct_model_dict(weights_path: Path, s1_or_s2: str) -> Dict:
    """Build the registry that maps each model name to its wrapper and arguments.

    Args:
        weights_path: Base path for model weights. Most entries receive it
            unchanged under the ``"weights_path"`` key; the SatMAE entries
            receive it joined with a checkpoint filename under
            ``"pretrained_path"``.
        s1_or_s2: ``"s1"`` selects the ``"SAR"`` modality for the
            modality-aware entries (CROMA, SoftCon, DeCur); any other value
            selects ``"optical"``. The ``"presto"`` entry uses
            ``PrestoWrapper`` for ``"s1"``/``"s2"`` and ``UnWrappedPresto``
            for anything else.

    Returns:
        A dict keyed by model name. Each value has the form
        ``{"model": <wrapper class>, "args": <dict of keyword arguments>}``
        (the arguments are presumably forwarded to the wrapper by the caller).
    """
    # Computed once and shared by every modality-aware entry.
    modality = "SAR" if s1_or_s2 == "s1" else "optical"
    presto_model = PrestoWrapper if s1_or_s2 in ["s1", "s2"] else UnWrappedPresto

    def entry(model, **args):
        """Pair a wrapper class with a fresh dict of its keyword arguments."""
        return {"model": model, "args": args}

    # Insertion order matters: callers that list the keys see this order.
    return {
        "mmearth_atto": entry(
            MMEarthWrapper, weights_path=weights_path, size="atto"
        ),
        "satmae_pp": entry(
            SatMAEWrapper,
            pretrained_path=weights_path / "satmae_pp.pth",
            size="large",
        ),
        "satlas_tiny": entry(
            SatlasWrapper, weights_path=weights_path, size="tiny"
        ),
        "croma_base": entry(
            CROMAWrapper, weights_path=weights_path, size="base", modality=modality
        ),
        "softcon_small": entry(
            SoftConWrapper, weights_path=weights_path, size="small", modality=modality
        ),
        "satmae_base": entry(
            SatMAEWrapper,
            pretrained_path=weights_path / "pretrain-vit-base-e199.pth",
            size="base",
        ),
        "dofa_base": entry(
            DOFAWrapper, weights_path=weights_path, size="base"
        ),
        "satlas_base": entry(
            SatlasWrapper, weights_path=weights_path, size="base"
        ),
        "croma_large": entry(
            CROMAWrapper, weights_path=weights_path, size="large", modality=modality
        ),
        "softcon_base": entry(
            SoftConWrapper, weights_path=weights_path, size="base", modality=modality
        ),
        "satmae_large": entry(
            SatMAEWrapper,
            pretrained_path=weights_path / "pretrain-vit-large-e199.pth",
            size="large",
        ),
        "dofa_large": entry(
            DOFAWrapper, weights_path=weights_path, size="large"
        ),
        "decur": entry(
            DeCurWrapper, weights_path=weights_path, modality=modality
        ),
        "presto": entry(presto_model),
        "anysat": entry(AnySatWrapper),
        "prithvi": entry(PrithviWrapper, weights_path=weights_path),
    }
def get_model_config(model_name: str, weights_path: Path, s1_or_s2: str):
    """Return the registry entry for ``model_name``.

    The full registry is built with ``construct_model_dict`` and then indexed,
    so an unknown ``model_name`` raises ``KeyError``.
    """
    registry = construct_model_dict(weights_path, s1_or_s2)
    return registry[model_name]
def get_all_model_names():
    """Return every model name in the registry, in registration order."""
    # Only the keys are read, so the weights path and the "s2" selector are
    # arbitrary placeholders that never affect the result.
    registry = construct_model_dict(Path("."), "s2")
    return list(registry)
# Names of all registered models, computed once when this module is imported.
BASELINE_MODELS = get_all_model_names()