File size: 783 Bytes
0def483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
"""
Model loading and device utilities.
"""
import torch
from model import HAT
from config import MODEL_CHECKPOINT, MODEL_CONFIG
def get_device():
    """Return the torch device that model inference should run on.

    CUDA is chosen whenever torch reports a usable GPU; otherwise the CPU is used.
    """
    if torch.cuda.is_available():
        backend = 'cuda'
    else:
        backend = 'cpu'
    return torch.device(backend)
def load_model():
    """Build the HAT model, restore its pre-trained weights and prepare it for inference.

    The model is constructed from ``MODEL_CONFIG`` and its weights are read from
    ``MODEL_CHECKPOINT``. A checkpoint may wrap the weights under a
    ``'params_ema'`` or ``'params'`` key (EMA weights are preferred when
    present and non-empty), or it may itself be a bare state dict.

    Returns:
        tuple: ``(model, device)``, where ``model`` is the HAT instance moved
        to ``device`` and switched to eval mode, and ``device`` is the
        ``torch.device`` returned by ``get_device()``.

    Raises:
        TypeError: if the loaded checkpoint is not a dict, so no state dict
            can be extracted from it.
    """
    device = get_device()

    # Initialize the (randomly weighted) model on the CPU.
    model = HAT(**MODEL_CONFIG)

    # Deserialize the weights onto the CPU. They are only needed long enough to
    # be copied into the model's parameters. Staging them on the GPU (the old
    # map_location=device) kept an extra full copy of every checkpoint tensor
    # in device memory while model.to(device) allocated the real parameters.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(MODEL_CHECKPOINT, map_location='cpu')
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected checkpoint {MODEL_CHECKPOINT!r} to contain a state dict, "
            f"got {type(checkpoint).__name__}"
        )

    # Try the known checkpoint layouts in order: EMA weights, raw weights,
    # then treat the whole checkpoint as the state dict.
    state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint
    model.load_state_dict(state_dict)
    # Drop the CPU copy of the weights before moving the model to the device.
    del checkpoint, state_dict

    model.to(device)
    model.eval()
    return model, device