LSTMGlucosePrediction / modeling_lstm.py
njeffrie's picture
Upload 3 files
54259e7 verified
raw
history blame
2.4 kB
import torch
from torch import nn
from typing import Optional, Tuple
try:
from transformers import PreTrainedModel, PretrainedConfig
except Exception as exc: # pragma: no cover
raise RuntimeError("transformers must be installed to use this model") from exc
class LSTMConfig(PretrainedConfig):
    """Hyperparameters for the LSTM time-series forecaster.

    Attributes:
        input_size: Feature count per timestep, passed to ``nn.LSTM``.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout rate handed to ``nn.LSTM``.
        len_seq: Number of timesteps in one input window.
        len_pred: Number of values the model emits per window.

    Any extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "lstm_timeseries"

    def __init__(
        self,
        input_size: int = 1,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.1,
        len_seq: int = 180,
        len_pred: int = 12,
        **kwargs,
    ) -> None:
        # Let the base class consume its own options first.
        super().__init__(**kwargs)

        # Recurrent network shape.
        self.input_size, self.hidden_size = input_size, hidden_size
        self.num_layers, self.dropout = num_layers, dropout

        # Window lengths: history fed in, horizon predicted.
        self.len_seq, self.len_pred = len_seq, len_pred
class _SimpleLSTM(nn.Module):
    """Stacked LSTM followed by a linear head on the final timestep.

    The LSTM reads a ``(batch, len_seq, input_size)`` window; the output of
    its last timestep is projected to ``config.len_pred`` values.
    """

    def __init__(self, config: "LSTMConfig") -> None:
        """Build the LSTM and projection layers from ``config``.

        Args:
            config: Object exposing ``input_size``, ``hidden_size``,
                ``num_layers``, ``dropout`` and ``len_pred``.
        """
        super().__init__()
        self.config = config

        # nn.LSTM applies dropout only between stacked layers and warns when
        # a non-zero rate is given for a single layer, where it has no
        # effect. Passing 0.0 in that case is numerically identical and
        # avoids the warning.
        inter_layer_dropout = config.dropout if config.num_layers > 1 else 0.0

        self.lstm = nn.LSTM(
            input_size=config.input_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.projection = nn.Linear(config.hidden_size, config.len_pred)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``len_pred`` values for each input window.

        Args:
            x: Tensor of shape ``(batch, len_seq, input_size)``.

        Returns:
            Tensor of shape ``(batch, len_pred)``.
        """
        lstm_out, _ = self.lstm(x)
        # Only the top layer's output at the final timestep feeds the head.
        last_hidden = lstm_out[:, -1, :]
        return self.projection(last_hidden)
class LSTMForTimeSeries(PreTrainedModel):
    """Glucose forecaster that wraps ``_SimpleLSTM`` with value scaling.

    ``forward`` accepts raw glucose readings, rescales them linearly,
    runs the LSTM, and maps the predictions back to the raw scale.
    """

    # Raw values that map to the two ends of the normalized range.
    # NOTE(review): presumably sensor limits in mg/dL -- confirm against the
    # training pipeline.
    UPPER = 402
    LOWER = 38
    # LOWER maps to -SCALE_1 and UPPER to SCALE_1 * SCALE_2 - SCALE_1,
    # i.e. [-5, 5] with the values below.
    SCALE_1 = 5
    SCALE_2 = 2
    config_class = LSTMConfig

    def __init__(self, config: LSTMConfig) -> None:
        """Create the wrapped network and put it in evaluation mode."""
        super().__init__(config)
        self.model = _SimpleLSTM(config)
        # Disable dropout by default; callers that fine-tune must call
        # ``train()`` themselves.
        self.model.eval()

    def normalize_glucose(self, glucose):
        """Linearly map raw glucose values onto the model's input scale."""
        return (glucose - self.LOWER) / (self.UPPER - self.LOWER) * (self.SCALE_1 * self.SCALE_2) - self.SCALE_1

    def unnormalize_glucose(self, glucose):
        """Invert ``normalize_glucose``, returning values on the raw scale."""
        return (glucose + self.SCALE_1) / (self.SCALE_1 * self.SCALE_2) * (self.UPPER - self.LOWER) + self.LOWER

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast future glucose values from raw history windows.

        Args:
            x: Raw glucose readings as a tensor or anything
                ``torch.as_tensor`` accepts (nested lists, NumPy arrays).
                The element count must be a multiple of
                ``len_seq * input_size``; the data is reshaped to
                ``(-1, len_seq, input_size)``.

        Returns:
            Tensor of shape ``(batch, len_pred)`` on the raw glucose scale.
        """
        if isinstance(x, torch.Tensor):
            # torch.tensor(tensor) warns and copies. detach() keeps the
            # previous behaviour of cutting the input from any autograd
            # graph without that overhead.
            x = x.detach()
        else:
            x = torch.as_tensor(x)

        # Run on whatever device the weights live on, so a model moved to an
        # accelerator still accepts CPU inputs.
        weights = next(self.model.parameters())
        x = x.to(device=weights.device, dtype=torch.float32)
        # Use the configured feature count rather than a hard-coded 1.
        x = x.reshape(-1, self.config.len_seq, self.config.input_size)

        normalized = self.normalize_glucose(x)
        out = self.model(normalized)
        return self.unnormalize_glucose(out)