Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import math | |
| from config import eta_1, eta_T, p, T | |
| ''' | |
| Timestep in our ResShift is a scalar in the range 0 - 14 | |
| ''' | |
def resshift_schedule(T=T, eta1=eta_1, etaT=eta_T, p=p):
    """Build the ResShift shifting schedule ``eta_1 .. eta_T`` as a tensor.

    Follows Eq. (10) of the ResShift paper::

        sqrt(eta_t) = sqrt(eta_1) * b0 ** beta_t
        beta_t      = ((t - 1) / (T - 1)) ** p * (T - 1)
        b0          = exp(log(eta_T / eta_1) / (2 * (T - 1)))

    Squaring the first line gives ``eta_t = eta_1 * b0 ** (2 * beta_t)``,
    which equals ``eta1`` at t = 1 and exactly ``etaT`` at t = T.

    Args:
        T: number of diffusion steps; must be >= 2 (the formula divides
            by ``T - 1``).
        eta1: first value of the schedule; must be > 0.
        etaT: last value of the schedule; must be > 0.
        p: exponent shaping how the schedule moves from ``eta1`` to ``etaT``.

    Returns:
        1-D tensor of length ``T`` (torch's default dtype, since it is built
        from Python floats). Entry ``i`` holds ``eta_{i+1}``, so timestep
        indices run ``0 .. T-1``.

    Raises:
        ValueError: if ``T < 2`` or if ``eta1`` / ``etaT`` are not positive.
    """
    if T < 2:
        raise ValueError(f"resshift_schedule requires T >= 2, got T={T}")
    if eta1 <= 0 or etaT <= 0:
        raise ValueError(
            f"resshift_schedule requires eta1 > 0 and etaT > 0, "
            f"got eta1={eta1}, etaT={etaT}"
        )

    b0 = math.exp(math.log(etaT / eta1) / (2 * (T - 1)))
    betas = [((t - 1) / (T - 1)) ** p * (T - 1) for t in range(1, T + 1)]

    # Eq. (10) defines sqrt(eta_t); squaring it doubles the exponent on b0.
    # Using b0 ** beta alone would end the schedule at sqrt(eta1 * etaT)
    # instead of etaT.
    eta = [eta1 * b0 ** (2 * beta) for beta in betas]
    return torch.tensor(eta)