| from __future__ import absolute_import | |
| import autograd.numpy as np | |
| import scipy.stats | |
| from autograd.extend import primitive, defvjp | |
| from autograd.numpy.numpy_vjps import unbroadcast_f | |
# Wrap SciPy's Poisson distribution functions as autograd primitives, so that
# autograd treats each as an opaque operation whose gradients are registered
# explicitly with ``defvjp`` rather than traced through SciPy's internals.
pmf = primitive(scipy.stats.poisson.pmf)
logpmf = primitive(scipy.stats.poisson.logpmf)
cdf = primitive(scipy.stats.poisson.cdf)
def grad_poisson_logpmf(k, mu):
    """Return the derivative of the Poisson log-pmf with respect to ``mu``.

    For a non-negative integer ``k`` the log-pmf is
    ``k * log(mu) - mu - log(k!)``, so its derivative with respect to ``mu``
    is ``k / mu - 1``. For any other ``k`` (non-integer or negative) the
    pmf is identically zero, so the log-pmf is the constant ``-inf`` and its
    derivative is taken to be 0.

    Args:
        k: Count(s) at which the log-pmf is evaluated; scalar or array,
            broadcastable against ``mu``.
        mu: Poisson rate(s). Presumably strictly positive; ``mu == 0`` makes
            ``k / mu`` divide by zero (NOTE(review): confirm intended handling).

    Returns:
        Elementwise derivative as an array, broadcast over ``k`` and ``mu``.
    """
    # The Poisson support is the non-negative integers. ``k % 1 == 0`` alone
    # would also admit negative integers, where the log-pmf is constant and
    # the gradient must be 0.
    in_support = np.logical_and(k % 1 == 0, k >= 0)
    return np.where(in_support, k / mu - 1, 0)
# Vector-Jacobian products with respect to the rate ``mu`` (argument index 1).
# Only ``argnums=[1]`` is registered, so ``k`` is not differentiable.
# ``unbroadcast_f(mu, ...)`` presumably reduces the incoming gradient back to
# ``mu``'s shape when ``k`` and ``mu`` were broadcast together -- TODO confirm.


def _poisson_cdf_vjp(ans, k, mu):
    # d/dmu P(K <= k; mu) = -pmf(floor(k); mu)
    return unbroadcast_f(mu, lambda g: g * -pmf(np.floor(k), mu))


def _poisson_logpmf_vjp(ans, k, mu):
    # The log-pmf derivative is delegated to the shared helper.
    return unbroadcast_f(mu, lambda g: g * grad_poisson_logpmf(k, mu))


def _poisson_pmf_vjp(ans, k, mu):
    # Chain rule: d pmf / d mu = pmf * d log(pmf) / d mu, with ``ans`` == pmf.
    return unbroadcast_f(mu, lambda g: g * ans * grad_poisson_logpmf(k, mu))


defvjp(cdf, _poisson_cdf_vjp, argnums=[1])
defvjp(logpmf, _poisson_logpmf_vjp, argnums=[1])
defvjp(pmf, _poisson_pmf_vjp, argnums=[1])