| """ | |
| Handy functions for flattening nested containers containing numpy | |
| arrays. The main purpose is to make examples and optimizers simpler. | |
| """ | |
| from autograd import make_vjp | |
| from autograd.builtins import type | |
| import autograd.numpy as np | |
def flatten(value):
    """Collapse a nested container into one flat vector plus its inverse.

    `value` may be any nesting of tuples, lists, or dicts whose leaves are
    numpy arrays or scalars. Mixed numeric types (e.g. floats and ints) are
    not preserved, and dict keys are assumed to be sortable.

    Returns a `(flat_value, unflatten)` pair: the 1D numpy array and the
    accompanying unflatten function.
    """
    # The call below yields the VJP function first and the forward result
    # second; that VJP function is what gets handed back as `unflatten`.
    vjp_builder = make_vjp(_flatten)
    unflatten, flat_value = vjp_builder(value)
    return flat_value, unflatten
def _flatten(value):
    """Recursively flatten `value` into a single 1D array.

    Lists and tuples are flattened element by element in order, dicts are
    flattened value by value in sorted-key order, and anything else is handed
    to `np.ravel`.
    """
    # `type` is the one imported from autograd.builtins at the top of the
    # file, not necessarily the Python builtin.
    kind = type(value)
    if kind in (list, tuple):
        return _concatenate([_flatten(item) for item in value])
    if kind is dict:
        return _concatenate([_flatten(value[key]) for key in sorted(value)])
    return np.ravel(value)
def _concatenate(lst):
    """Join an iterable of arrays end to end into one array.

    The iterable is materialized into a list first, so generators and `map`
    objects are accepted. An empty iterable produces `np.array([])` instead
    of being passed to `np.concatenate`.
    """
    pieces = list(lst)
    if not pieces:
        return np.array([])
    return np.concatenate(pieces)
def flatten_func(func, example):
    """Wrap `func` so that it takes and returns flat vectors.

    `example` is flattened once to obtain a flat example and the matching
    `unflatten` function. The wrapper unflattens its first argument, calls
    `func` on it (forwarding any extra positional arguments), and flattens
    whatever `func` returns.

    Returns a `(wrapped_func, unflatten, flat_example)` triple.
    """
    flat_example, unflatten = flatten(example)

    def flat_func(_x, *args):
        structured_input = unflatten(_x)
        result = func(structured_input, *args)
        # Only the flat array is kept; the unflatten function for the
        # output is discarded.
        return flatten(result)[0]

    return flat_func, unflatten, flat_example