| from __future__ import absolute_import | |
| import autograd.numpy as np | |
| import scipy.stats | |
| from autograd.extend import primitive, defvjp | |
| from autograd.numpy.numpy_vjps import unbroadcast_f | |
| from autograd.scipy.special import gamma, psi | |
# Wrap SciPy's gamma-distribution routines with autograd's ``primitive`` so
# that the VJPs registered for them further down in this module can be attached.
cdf, logpdf, pdf = map(
    primitive,
    (scipy.stats.gamma.cdf, scipy.stats.gamma.logpdf, scipy.stats.gamma.pdf),
)
def grad_gamma_logpdf_arg0(x, a):
    """Return ``(a - x - 1) / x``.

    This is the factor used by the ``logpdf`` and ``pdf`` VJPs for their
    first argument ``x``; ``a`` is the second (shape) argument.
    """
    numerator = a - x - 1
    return numerator / x
def grad_gamma_logpdf_arg1(x, a):
    """Return ``log(x) - psi(a)``.

    This is the factor used by the ``logpdf`` and ``pdf`` VJPs for their
    second argument ``a``; ``psi`` is the function imported from
    ``autograd.scipy.special``.
    """
    log_x = np.log(x)
    psi_a = psi(a)
    return log_x - psi_a
# VJP for ``cdf``: the incoming gradient ``g`` is scaled by
# exp(-x) * x**(a - 1) / gamma(a). Only argument 0 (``x``) gets a VJP here
# (``argnums=[0]``); none is registered for ``a``.
defvjp(
    cdf,
    lambda ans, x, a: unbroadcast_f(
        x,
        lambda g: g * np.exp(-x) * np.power(x, a - 1) / gamma(a),
    ),
    argnums=[0],
)
# VJPs for ``logpdf``: one per argument, in order (``x`` first, then ``a``).
# Each scales the incoming gradient ``g`` by the matching helper above and
# passes the result through ``unbroadcast_f`` keyed on that argument.
defvjp(
    logpdf,
    lambda ans, x, a: unbroadcast_f(
        x, lambda g: g * grad_gamma_logpdf_arg0(x, a)
    ),
    lambda ans, x, a: unbroadcast_f(
        a, lambda g: g * grad_gamma_logpdf_arg1(x, a)
    ),
)
# VJPs for ``pdf``: same structure as the ``logpdf`` VJPs, with an extra
# factor of ``ans`` (the value ``pdf`` returned), i.e. the chain rule
# d(pdf) = pdf * d(logpdf).
defvjp(
    pdf,
    lambda ans, x, a: unbroadcast_f(
        x, lambda g: g * ans * grad_gamma_logpdf_arg0(x, a)
    ),
    lambda ans, x, a: unbroadcast_f(
        a, lambda g: g * ans * grad_gamma_logpdf_arg1(x, a)
    ),
)