from __future__ import absolute_import import types import warnings from autograd.extend import primitive, notrace_primitive import numpy as _np import autograd.builtins as builtins from numpy.core.einsumfunc import _parse_einsum_input notrace_functions = [ _np.ndim, _np.shape, _np.iscomplexobj, _np.result_type ] def wrap_intdtype(cls): class IntdtypeSubclass(cls): __new__ = notrace_primitive(cls.__new__) return IntdtypeSubclass def wrap_namespace(old, new): unchanged_types = {float, int, type(None), type} int_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer} for name, obj in old.items(): if obj in notrace_functions: new[name] = notrace_primitive(obj) elif callable(obj) and type(obj) is not type: new[name] = primitive(obj) elif type(obj) is type and obj in int_types: new[name] = wrap_intdtype(obj) elif type(obj) in unchanged_types: new[name] = obj wrap_namespace(_np.__dict__, globals()) # ----- Special treatment of list-input functions ----- @primitive def concatenate_args(axis, *args): return _np.concatenate(args, axis).view(ndarray) concatenate = lambda arr_list, axis=0 : concatenate_args(axis, *arr_list) vstack = row_stack = lambda tup: concatenate([atleast_2d(_m) for _m in tup], axis=0) def hstack(tup): arrs = [atleast_1d(_m) for _m in tup] if arrs[0].ndim == 1: return concatenate(arrs, 0) return concatenate(arrs, 1) def column_stack(tup): arrays = [] for v in tup: arr = array(v) if arr.ndim < 2: arr = array(arr, ndmin=2).T arrays.append(arr) return concatenate(arrays, 1) def array(A, *args, **kwargs): t = builtins.type(A) if t in (list, tuple): return array_from_args(args, kwargs, *map(array, A)) else: return _array_from_scalar_or_array(args, kwargs, A) def wrap_if_boxes_inside(raw_array, slow_op_name=None): if raw_array.dtype is _np.dtype('O'): if slow_op_name: warnings.warn("{0} is slow for array inputs. " "np.concatenate() is faster.".format(slow_op_name)) return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape) else: return raw_array @primitive def _array_from_scalar_or_array(array_args, array_kwargs, scalar): return _np.array(scalar, *array_args, **array_kwargs) @primitive def array_from_args(array_args, array_kwargs, *args): return _np.array(args, *array_args, **array_kwargs) def select(condlist, choicelist, default=0): raw_array = _np.select(list(condlist), list(choicelist), default=default) return array(list(raw_array.ravel())).reshape(raw_array.shape) def stack(arrays, axis=0): # this code is basically copied from numpy/core/shape_base.py's stack # we need it here because we want to re-implement stack in terms of the # primitives defined in this file arrays = [array(arr) for arr in arrays] if not arrays: raise ValueError('need at least one array to stack') shapes = set(arr.shape for arr in arrays) if len(shapes) != 1: raise ValueError('all input arrays must have the same shape') result_ndim = arrays[0].ndim + 1 if not -result_ndim <= axis < result_ndim: raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)) if axis < 0: axis += result_ndim sl = (slice(None),) * axis + (None,) return concatenate([arr[sl] for arr in arrays], axis=axis) def append(arr, values, axis=None): # this code is basically copied from numpy/lib/function_base.py's append arr = array(arr) if axis is None: if ndim(arr) != 1: arr = ravel(arr) values = ravel(array(values)) axis = ndim(arr) - 1 return concatenate((arr, values), axis=axis) # ----- Enable functions called using [] ---- class r_class(): def __getitem__(self, args): raw_array = _np.r_[args] return wrap_if_boxes_inside(raw_array, slow_op_name = "r_") r_ = r_class() class c_class(): def __getitem__(self, args): raw_array = _np.c_[args] return wrap_if_boxes_inside(raw_array, slow_op_name = "c_") c_ = c_class() # ----- misc ----- @primitive def make_diagonal(D, offset=0, axis1=0, axis2=1): # Numpy doesn't offer a complement to np.diagonal: a function to create new # diagonal arrays with extra dimensions. We need such a function for the # gradient of np.diagonal and it's also quite handy to have. So here it is. if not (offset==0 and axis1==-1 and axis2==-2): raise NotImplementedError("Currently make_diagonal only supports offset=0, axis1=-1, axis2=-2") # We use a trick: calling np.diagonal returns a view on the original array, # so we can modify it in-place. (only valid for numpy version >= 1.10.) new_array = _np.zeros(D.shape + (D.shape[-1],)) new_array_diag = _np.diagonal(new_array, offset=0, axis1=-1, axis2=-2) new_array_diag.flags.writeable = True new_array_diag[:] = D return new_array @notrace_primitive def metadata(A): return _np.shape(A), _np.ndim(A), _np.result_type(A), _np.iscomplexobj(A) @notrace_primitive def parse_einsum_input(*args): return _parse_einsum_input(args) @primitive def _astype(A, dtype, order='K', casting='unsafe', subok=True, copy=True): return A.astype(dtype, order, casting, subok, copy)