| """Gradients of the normal distribution.""" | |
| from __future__ import absolute_import | |
| import scipy.stats | |
| import autograd.numpy as anp | |
| from autograd.extend import primitive, defvjp | |
| from autograd.numpy.numpy_vjps import unbroadcast_f | |
# Wrap each scipy.stats.norm function with autograd's ``primitive`` so that
# gradients can be attached to it with ``defvjp``.  The tuple order matches
# the unpacking order on the left-hand side.
pdf, cdf, sf, logpdf, logcdf, logsf = map(
    primitive,
    (
        scipy.stats.norm.pdf,
        scipy.stats.norm.cdf,
        scipy.stats.norm.sf,
        scipy.stats.norm.logpdf,
        scipy.stats.norm.logcdf,
        scipy.stats.norm.logsf,
    ),
)
def _pdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``pdf`` with respect to ``x``.

    ``ans`` is used as the value of ``pdf(x, loc, scale)``.
    """
    def vjp(g):
        # d pdf / d x = -pdf * (x - loc) / scale**2
        return -g * ans * (x - loc) / scale**2
    return unbroadcast_f(x, vjp)


def _pdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``pdf`` with respect to ``loc``."""
    def vjp(g):
        # d pdf / d loc = +pdf * (x - loc) / scale**2 (opposite sign to x).
        return g * ans * (x - loc) / scale**2
    return unbroadcast_f(loc, vjp)


def _pdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``pdf`` with respect to ``scale``."""
    def vjp(g):
        # d pdf / d scale = pdf * (z**2 - 1) / scale, with z = (x - loc) / scale.
        return g * ans * (((x - loc)/scale)**2 - 1.0)/scale
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(pdf, _pdf_vjp_x, _pdf_vjp_loc, _pdf_vjp_scale)
def _cdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``cdf`` with respect to ``x``."""
    def vjp(g):
        # d cdf / d x is the density itself.
        return g * pdf(x, loc, scale)
    return unbroadcast_f(x, vjp)


def _cdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``cdf`` with respect to ``loc``."""
    def vjp(g):
        # Shifting loc acts like shifting x the other way, hence the minus sign.
        return -g * pdf(x, loc, scale)
    return unbroadcast_f(loc, vjp)


def _cdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``cdf`` with respect to ``scale``."""
    def vjp(g):
        # d cdf / d scale = -pdf * (x - loc) / scale
        return -g * pdf(x, loc, scale)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(cdf, _cdf_vjp_x, _cdf_vjp_loc, _cdf_vjp_scale)
def _logpdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logpdf`` with respect to ``x``."""
    def vjp(g):
        # d logpdf / d x = -(x - loc) / scale**2
        return -g * (x - loc) / scale**2
    return unbroadcast_f(x, vjp)


def _logpdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logpdf`` with respect to ``loc``."""
    def vjp(g):
        # d logpdf / d loc = +(x - loc) / scale**2
        return g * (x - loc) / scale**2
    return unbroadcast_f(loc, vjp)


def _logpdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logpdf`` with respect to ``scale``."""
    def vjp(g):
        # d logpdf / d scale = -1/scale + (x - loc)**2 / scale**3
        return g * (-1.0/scale + (x - loc)**2/scale**3)
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(logpdf, _logpdf_vjp_x, _logpdf_vjp_loc, _logpdf_vjp_scale)
# Gradients of logcdf.  All three use the ratio pdf / cdf, evaluated in log
# space as exp(logpdf - logcdf).  autograd passes the primitive's output to
# each VJP maker as ``ans``, so ``ans`` is reused for logcdf(x, loc, scale)
# instead of evaluating logcdf a second time on the backward pass.
def _logcdf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logcdf`` with respect to ``x``."""
    def vjp(g):
        # d logcdf / d x = pdf / cdf
        return g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(x, vjp)


def _logcdf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logcdf`` with respect to ``loc``."""
    def vjp(g):
        # d logcdf / d loc = -pdf / cdf
        return -g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(loc, vjp)


def _logcdf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logcdf`` with respect to ``scale``."""
    def vjp(g):
        # d logcdf / d scale = -(pdf / cdf) * (x - loc) / scale
        return -g * anp.exp(logpdf(x, loc, scale) - ans)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(logcdf, _logcdf_vjp_x, _logcdf_vjp_loc, _logcdf_vjp_scale)
# Gradients of logsf.  All three use the ratio pdf / sf, evaluated in log
# space as exp(logpdf - logsf).  autograd passes the primitive's output to
# each VJP maker as ``ans``, so ``ans`` is reused for logsf(x, loc, scale)
# instead of evaluating logsf a second time on the backward pass.
def _logsf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logsf`` with respect to ``x``."""
    def vjp(g):
        # d logsf / d x = -pdf / sf
        return -g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(x, vjp)


def _logsf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logsf`` with respect to ``loc``."""
    def vjp(g):
        # d logsf / d loc = +pdf / sf
        return g * anp.exp(logpdf(x, loc, scale) - ans)
    return unbroadcast_f(loc, vjp)


def _logsf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``logsf`` with respect to ``scale``."""
    def vjp(g):
        # d logsf / d scale = (pdf / sf) * (x - loc) / scale
        return g * anp.exp(logpdf(x, loc, scale) - ans) * (x - loc) / scale
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(logsf, _logsf_vjp_x, _logsf_vjp_loc, _logsf_vjp_scale)
def _sf_vjp_x(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``sf`` with respect to ``x``."""
    def vjp(g):
        # sf = 1 - cdf, so d sf / d x = -pdf.
        return -g * pdf(x, loc, scale)
    return unbroadcast_f(x, vjp)


def _sf_vjp_loc(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``sf`` with respect to ``loc``."""
    def vjp(g):
        # d sf / d loc = +pdf
        return g * pdf(x, loc, scale)
    return unbroadcast_f(loc, vjp)


def _sf_vjp_scale(ans, x, loc=0.0, scale=1.0):
    """Return the VJP of ``sf`` with respect to ``scale``."""
    def vjp(g):
        # d sf / d scale = pdf * (x - loc) / scale
        return g * pdf(x, loc, scale)*(x-loc)/scale
    return unbroadcast_f(scale, vjp)


# One VJP maker per positional argument, in the order (x, loc, scale).
defvjp(sf, _sf_vjp_x, _sf_vjp_loc, _sf_vjp_scale)