File size: 400 Bytes
3c45764
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn
import math
from config import eta_1, eta_T, p, T
'''
Timesteps in our ResShift implementation are scalar indices in the range 0 - 14.
'''

def resshift_schedule(T=T, eta1=eta_1, etaT=eta_T, p=p):
    """Build the ResShift shifting schedule eta_1 .. eta_T.

    The schedule is geometric in sqrt(eta):

        sqrt(eta_t) = sqrt(eta1) * b0 ** beta_t
        beta_t      = ((t - 1) / (T - 1)) ** p * (T - 1)
        b0          = exp(log(etaT / eta1) / (2 * (T - 1)))

    which, after squaring, simplifies to

        eta_t = eta1 * (etaT / eta1) ** (((t - 1) / (T - 1)) ** p)

    so that eta_1 == eta1 and eta_T == etaT (for p > 0).

    Args:
        T: Number of steps in the schedule.
        eta1: Value of the schedule at the first step.
        etaT: Value of the schedule at the last step.
        p: Exponent applied to the normalised step (t - 1) / (T - 1);
            it controls how quickly the schedule grows.

    Returns:
        A 1-D ``torch.Tensor`` with T entries; entry ``i`` holds eta for
        step ``t = i + 1``.

    Raises:
        ZeroDivisionError: If ``T == 1`` or ``eta1 == 0``.
        ValueError: If ``etaT / eta1`` is not positive (``math.log`` domain).
    """
    # Taking the log up front keeps the domain check on etaT / eta1.
    log_ratio = math.log(etaT / eta1)

    eta = []
    for t in range(1, T + 1):
        progress = ((t - 1) / (T - 1)) ** p
        # b0 is the base of the sqrt(eta) recurrence, so eta itself needs
        # b0 ** (2 * beta_t); the factor 2 cancels b0's 1/2, leaving
        # exp(progress * log_ratio). Using b0 ** beta_t directly would end
        # the schedule at sqrt(eta1 * etaT) instead of etaT.
        eta.append(eta1 * math.exp(progress * log_ratio))
    return torch.tensor(eta)