| from __future__ import absolute_import, division | |
| import autograd.numpy as np | |
| import scipy.stats | |
| from autograd.extend import primitive, defvjp | |
| from autograd.numpy.numpy_vjps import unbroadcast_f | |
| from autograd.scipy.special import gamma | |
# Expose scipy's chi-squared cdf / logpdf / pdf as autograd primitives, so
# custom vector-Jacobian products can be attached to them with ``defvjp``.
cdf, logpdf, pdf = map(
    primitive,
    (scipy.stats.chi2.cdf, scipy.stats.chi2.logpdf, scipy.stats.chi2.pdf),
)
def grad_chi2_logpdf(x, df):
    """Return d/dx of the chi-squared log-density at ``x`` with ``df`` degrees of freedom.

    With k = df:
        logpdf(x; k) = (k/2 - 1) * log(x) - x/2 - (k/2) * log(2) - log(Gamma(k/2))
    so
        d/dx logpdf = (k/2 - 1) / x - 1/2 = (k - x - 2) / (2 * x).

    The formula holds for any real ``df``. scipy.stats.chi2 accepts non-integer
    degrees of freedom, so the gradient must not be restricted to integer ``df``.
    Inputs broadcast elementwise like ordinary array arithmetic.

    Args:
        x: point(s) at which to evaluate the derivative.
        df: degrees of freedom.

    Returns:
        The elementwise derivative of ``logpdf`` with respect to ``x``.
    """
    # The previous version masked the result to 0 whenever df was not an
    # integer, which silently produced a wrong gradient for valid fractional df.
    return (df - x - 2) / (2 * x)
# Register vector-Jacobian products for the chi-squared primitives.
# Gradients are defined for ``x`` only (argnums=[0]); ``df`` is not differentiated.

# d/dx cdf(x, df) is the density itself. Evaluating it through the ``pdf``
# primitive avoids the closed form
#   2**(-df/2) * exp(-x/2) * x**(df/2 - 1) / gamma(df/2).
# That form becomes inf/inf = nan for large df, because gamma(df/2) overflows
# float64 once df/2 exceeds ~171. It also gives nan for x < 0 with a
# non-integer exponent, where the density is 0. Using the ``pdf`` primitive
# also keeps the result differentiable in ``x`` for higher-order gradients.
defvjp(cdf,
       lambda ans, x, df: unbroadcast_f(x, lambda g: g * pdf(x, df)),
       argnums=[0])

# d/dx logpdf(x, df) = (df - x - 2) / (2x).
defvjp(logpdf,
       lambda ans, x, df: unbroadcast_f(x, lambda g: g * grad_chi2_logpdf(x, df)),
       argnums=[0])

# d/dx pdf = pdf * d/dx logpdf; ``ans`` is the already-computed pdf value.
defvjp(pdf,
       lambda ans, x, df: unbroadcast_f(x, lambda g: g * ans * grad_chi2_logpdf(x, df)),
       argnums=[0])