Spaces:
Sleeping
Sleeping
File size: 400 Bytes
3c45764 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
import torch.nn as nn
import math
from config import eta_1, eta_T, p, T
'''
Timesteps in our ResShift are scalar integer indices from 0 to 14.
'''
def resshift_schedule(T=T, eta1=eta_1, etaT=eta_T, p=p):
    """Build the ResShift shifting schedule eta_1, ..., eta_T.

    Implements

        sqrt(eta_t) = sqrt(eta_1) * b0 ** beta_t
        beta_t      = ((t - 1) / (T - 1)) ** p * (T - 1)
        b0          = exp(log(eta_T / eta_1) / (2 * (T - 1)))

    for t = 1..T, so the first entry equals ``eta1`` and the last equals
    ``etaT``.

    Args:
        T: Number of diffusion steps; must be at least 2.
        eta1: Schedule value at the first step (positive).
        etaT: Schedule value at the last step (positive).
        p: Exponent shaping how fast the schedule grows between the ends.

    Returns:
        1-D tensor of length ``T`` in torch's default float dtype; index
        ``i`` holds ``eta_{i+1}``.

    Raises:
        ValueError: If ``T < 2`` (the schedule divides by ``T - 1``).
    """
    if T < 2:
        raise ValueError(f"resshift_schedule needs T >= 2, got T={T}")
    betas = [((t - 1) / (T - 1)) ** p * (T - 1) for t in range(1, T + 1)]
    b0 = math.exp(math.log(etaT / eta1) / (2 * (T - 1)))
    # b0 ** beta_t is the growth of sqrt(eta_t); squaring it (exponent
    # 2 * beta_t) yields eta_t itself, which ends exactly at etaT. Using
    # b0 ** beta_t directly would end at sqrt(eta1 * etaT) instead.
    eta = [eta1 * b0 ** (2 * beta) for beta in betas]
    return torch.tensor(eta)
|