18256559666a70d7f16bb789b379220578acb9d0f204be811e183cc0a99d467b
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +2 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h +16 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h +70 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h +481 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h +538 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h +176 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h +246 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h +31 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h +31 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h +370 -0
- lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h +169 -0
- lib/python3.11/site-packages/mlx/include/mlx/device.h +29 -0
- lib/python3.11/site-packages/mlx/include/mlx/dtype.h +105 -0
- lib/python3.11/site-packages/mlx/include/mlx/fft.h +151 -0
- lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h +23 -0
- lib/python3.11/site-packages/mlx/include/mlx/io/load.h +114 -0
- lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h +32 -0
- lib/python3.11/site-packages/mlx/include/mlx/linalg.h +63 -0
- lib/python3.11/site-packages/mlx/include/mlx/mlx.h +14 -0
- lib/python3.11/site-packages/mlx/include/mlx/ops.h +1094 -0
- lib/python3.11/site-packages/mlx/include/mlx/primitives.h +1636 -0
- lib/python3.11/site-packages/mlx/include/mlx/random.h +193 -0
- lib/python3.11/site-packages/mlx/include/mlx/scheduler.h +173 -0
- lib/python3.11/site-packages/mlx/include/mlx/stream.h +32 -0
- lib/python3.11/site-packages/mlx/include/mlx/transforms.h +187 -0
- lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h +17 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h +187 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/complex.h +77 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h +234 -0
- lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h +56 -0
- lib/python3.11/site-packages/mlx/include/mlx/utils.h +44 -0
- lib/python3.11/site-packages/mlx/lib/libmlx.dylib +3 -0
- lib/python3.11/site-packages/mlx/lib/mlx.metallib +3 -0
- lib/python3.11/site-packages/mlx/nn/__init__.py +5 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__init__.py +63 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/llvmlite/binding/libllvmlite.dylib filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/libmlx.dylib filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/mlx.metallib filter=lfs diff=lfs merge=lfs -text
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h
ADDED
@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.

#pragma once

#ifdef __METAL__
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h
ADDED
@@ -0,0 +1,70 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_math>

/*
 * Approximation to the error function.
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 */
float erf(float a) {
  float r, s, t, u;
  t = metal::abs(a);
  s = a * a;
  if (t > 0.927734375f) {
    // maximum error 0.99527 ulp
    r = metal::fma(
        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
    u = metal::fma(
        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
    r = metal::fma(r, s, u);
    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
    r = metal::fma(r, t, -t);
    // TODO, replace with expm1 when implemented
    r = 1.0f - metal::exp(r);
    r = metal::copysign(r, a);
  } else {
    // maximum error 0.98929 ulp
    r = -5.96761703e-4f; // -0x1.38e000p-11
    r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
    r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
    r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
    r = metal::fma(r, a, a);
  }
  return r;
}

float erfinv(float a) {
  auto t = metal::fma(a, 0.0f - a, 1.0f);
  t = metal::log(t);
  float p;
  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    p = 3.03697567e-10f; // 0x1.4deb44p-32
    p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
    p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
    p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
    p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    p = 5.43877832e-9f; // 0x1.75c000p-28
    p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
    p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
    p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
    p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h
ADDED
@@ -0,0 +1,481 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/conv_params.h"

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DInputBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BM;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  int offsets_n[n_rows];
  int offsets_oh[n_rows];
  int offsets_ow[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params.oS[0] * params.oS[1];

    for (int i = 0; i < n_rows; ++i) {
      int offset_nhw = tid.y * BM + bi + i * bstride;
      offsets_n[i] = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      offsets_oh[i] = hw / params.oS[1];
      offsets_ow[i] = hw % params.oS[1];
    }

    (void)lid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
      int n = offsets_n[i];
      int oh = offsets_oh[i];
      int ow = offsets_ow[i];

      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];

      // Read from input if in bounds
      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
        const device T* curr_src = src + n * params.in_strides[0] +
            ih * params.in_strides[1] + iw * params.in_strides[2];

#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = curr_src[j];
        }
      }

      // Zero pad otherwise
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int vec_size,
    int tgp_size,
    int tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = BN;
  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
  MLX_MTL_CONST int n_vecs = BK / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
  MLX_MTL_CONST int n_rows = dst_fd / bstride;

  // Leading dimension for src
  const int src_ld;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>& params;

  int weight_h;
  int weight_w;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const constant MLXConvParams<2>& params_,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_.wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_h(0),
        weight_w(0) {
    (void)lid;
    (void)tid;
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    const device T* curr_src =
        src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2];
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = curr_src[i * src_ld + j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    if (++weight_w < params.wS[1]) {
      return;
    }

    weight_w = 0;

    if (++weight_h < params.wS[0]) {
      return;
    }

    weight_h = 0;

    src += BK;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DBlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC Conv2DBlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct Conv2DImplicitGEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t =
      Conv2DInputBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_a>;
  using loader_b_t =
      Conv2DWeightBlockLoader<T, BM, BN, BK, vec_size, tgp_size, tgp_padding_b>;
  using mma_t = Conv2DBlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant MLXConvParams<2>& params [[buffer(3)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;
    const int K = params.wt_strides[0];
    const int N = params.O;

    B += c_col * K;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid);
    loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_gid, simd_lid);

    for (int k = 0; k < K; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Store results to device memory
    mma_op.store_result(C, N);
  }
};
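
Editor's note, not part of the commit: Conv2DImplicitGEMMKernel::run above treats the 2-D convolution as a GEMM with K = params.wt_strides[0], N = params.O, and one GEMM row per output pixel (via offset_nhw in the input loader). The host-side sketch below just reproduces that size mapping for a contiguous channels-last layout; the ConvShape struct and its field names are illustrative assumptions, not MLX API.

// Editor sketch: implicit-GEMM problem sizes implied by the conv kernel above.
// Assumes contiguous NHWC input and (O, wH, wW, C) weights, so
// wt_strides[0] == wH * wW * C. All names here are illustrative.
#include <cstdio>

struct ConvShape {
  int N, iH, iW, C; // input batch, spatial size, channels
  int O, wH, wW;    // number of filters and kernel size
  int oH, oW;       // output spatial size
};

int main() {
  ConvShape s{8, 32, 32, 64, 128, 3, 3, 30, 30};
  int gemm_M = s.N * s.oH * s.oW; // one GEMM row per output pixel (offset_nhw)
  int gemm_N = s.O;               // params.O
  int gemm_K = s.wH * s.wW * s.C; // params.wt_strides[0] for contiguous weights
  std::printf("implicit GEMM: M=%d N=%d K=%d\n", gemm_M, gemm_N, gemm_K);
  return 0;
}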
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h
ADDED
@@ -0,0 +1,538 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#define MLX_MTL_CONST static constant constexpr const

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BROWS,
    int BCOLS,
    int BK,
    int vec_size,
    int tgp_size,
    bool transpose,
    bool ldK,
    int tgp_padding = 0>
struct BlockLoader {
  // Destination dimensions
  MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS;
  MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding;
  MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size;

  // Stride along block row within the block
  MLX_MTL_CONST int bstride = tgp_size / n_vecs;

  // Leading dimension for src
  const int src_ld;
  // Stride along reduction axis between blocks
  const int tstride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tstride(
            BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / n_vecs),
        bj(vec_size * (thread_idx % n_vecs)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
#pragma clang loop unroll(full)
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = src[i * src_ld + j];
      }
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy;

    // Iterate over rows of block
#pragma clang loop unroll(full)
    for (short i = 0; i < dst_fd; i += bstride) {
      // Row is in bounds, we check against column
      if ((bi + i) < src_tile_dim.y) {
        // Use fast thread memory for bound checks
        short tmp_idx[vec_size];
        T tmp_val[vec_size];

        // Make sure tmp_idx only contains valid indices
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0;
        }

        // Read all valid indices into tmp_val
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = src[i * src_ld + tmp_idx[j]];
        }

        // Zero out unneeded values
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          tmp_val[j] = bj + j < src_tile_dim.x ? tmp_val[j] : T(0);
        }

        // Copy values to threadgroup memory
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = tmp_val[j];
        }
      }

      // Row is out of bounds, we just fill tgp memory with zeros
      else {
#pragma clang loop unroll(full)
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tstride;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Transforms
///////////////////////////////////////////////////////////////////////////////

template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }
};

template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    int tgp_padding_a = 0,
    int tgp_padding_b = 0,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct BlockMMA {
  // Warp tile size along M
  MLX_MTL_CONST int TM = BM / (WM * 8);
  // Warp tile size along N
  MLX_MTL_CONST int TN = BN / (WN * 8);

  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TM_stride = 8 * WM;
  // Warp tile simdgroup matrix strides along M
  MLX_MTL_CONST int TN_stride = 8 * WN;

  // Leading dimensions of threadgroup A, B blocks
  MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a;
  MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b;

  // Strides of A, B along reduction axis
  MLX_MTL_CONST short simd_stride_a =
      transpose_a ? TM_stride : TM_stride * lda_tgp;
  MLX_MTL_CONST short simd_stride_b =
      transpose_b ? TN_stride * ldb_tgp : TN_stride;

  // Jump between elements
  MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1;
  MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1;

  // Offsets within threadgroup
  const int tm;
  const int tn;

  // Simdgroup matrices
  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
      simdgroup_matrix<AccumType, 8, 8>(0)};

  short sm;
  short sn;

  /* Constructor */
  METAL_FUNC BlockMMA(
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
    short qid = simd_lane_id / 4;
    sm = (qid & 4) + (simd_lane_id / 2) % 4;
    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Iterate over BK in blocks of 8
#pragma clang loop unroll(full)
    for (short kk = 0; kk < BK; kk += 8) {
      short2 offset_a =
          transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm);
      short2 offset_b =
          transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm);

      const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x;
      const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x;

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup A as simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
        Asimd[i].thread_elements()[0] = static_cast<AccumType>(As__[0]);
        Asimd[i].thread_elements()[1] = static_cast<AccumType>(As__[jump_a]);
        As__ += simd_stride_a;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Load elements from threadgroup B as simdgroup matrices
#pragma clang loop unroll(full)
      for (short j = 0; j < TN; j++) {
        Bsimd[j].thread_elements()[0] = static_cast<AccumType>(Bs__[0]);
        Bsimd[j].thread_elements()[1] = static_cast<AccumType>(Bs__[jump_b]);
        Bs__ += simd_stride_b;
      }

      simdgroup_barrier(mem_flags::mem_none);
      // Multiply and accumulate into result simdgroup matrices
#pragma clang loop unroll(full)
      for (short i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
        for (short j = 0; j < TN; j++) {
          simdgroup_multiply_accumulate(
              results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory */
  METAL_FUNC void store_result(device T* C, const int ldc) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
#pragma clang loop unroll(full)
      for (int j = 0; j < TN; j++) {
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] =
            Epilogue::apply(results[i * TN + j].thread_elements()[0]);
        C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] =
            Epilogue::apply(results[i * TN + j].thread_elements()[1]);
      }
    }
  }

  METAL_FUNC void
  store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const {
#pragma clang loop unroll(full)
    for (int i = 0; i < TM; i++) {
      if (tm + i * TM_stride + sm < dst_tile_dims.y) {
#pragma clang loop unroll(full)
        for (int j = 0; j < TN; j++) {
          if (tn + j * TN_stride + sn < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] =
                Epilogue::apply(results[i * TN + j].thread_elements()[0]);
          }

          if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) {
            C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] =
                Epilogue::apply(results[i * TN + j].thread_elements()[1]);
          }
        }
      }
    }
  }
};

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<T, AccumType>>
struct GEMMKernel {
  MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T);
  MLX_MTL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  MLX_MTL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  MLX_MTL_CONST short tgp_size = WM * WN * 32;
  MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4;

  using loader_a_t = BlockLoader<
      T,
      BM,
      BK,
      BK,
      vec_size,
      tgp_size,
      transpose_a,
      true,
      tgp_padding_a>;
  using loader_b_t = BlockLoader<
      T,
      BK,
      BN,
      BK,
      vec_size,
      tgp_size,
      transpose_b,
      false,
      tgp_padding_b>;
  using mma_t = BlockMMA<
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      tgp_padding_a,
      tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device T* C [[buffer(2)]],
      const constant int& M [[buffer(3)]],
      const constant int& N [[buffer(4)]],
      const constant int& K [[buffer(5)]],
      const constant int& batch_stride_a [[buffer(6)]],
      const constant int& batch_stride_b [[buffer(7)]],
      const constant int& batch_stride_c [[buffer(8)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Adjust for batch
    A += batch_stride_a * tid.z;
    B += batch_stride_b * tid.z;
    C += batch_stride_c * tid.z;

    // Adjust for transpose
    const int lda_dev = transpose_a ? M : K;
    const int ldb_dev = transpose_b ? K : N;

    // Find block in A, B, C
    const int c_row = tid.y * BM;
    const int c_col = tid.x * BN;

    A += transpose_a ? c_row : c_row * K;
    B += transpose_b ? c_col * K : c_col;
    C += c_row * N + c_col;

    // Prepare threadgroup memory for loading
    threadgroup T* As = tgp_memory;
    threadgroup T* Bs = tgp_memory + tgp_mem_size_a;

    // Prepare threadgroup loading operations
    loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id);
    loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    mma_t mma_op(simd_group_id, simd_lane_id);

    ///////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned && K_aligned) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////
    // MN aligned, K unaligned loop
    else if (MN_aligned && !K_aligned) {
      // Main loop
      int k = 0;
      for (; k + BK <= K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Loop tail
      threadgroup_barrier(mem_flags::mem_threadgroup);

      loader_a.load_safe(short2(K - k, BM));
      loader_b.load_safe(short2(BN, K - k));

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);

      // Store results to device memory
      mma_op.store_result(C, N);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////
    // MNK unaligned loop
    else { // Loop over K - unaligned case

      short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row));

      if (src_tile_dims.y == BM && src_tile_dims.x == BN) {
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_unsafe();
          loader_b.load_unsafe();

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, BM));
          loader_b.load_safe(short2(BN, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        mma_op.store_result(C, N);
        return;

      } else {
        int k = 0;
        for (; k + BK <= K; k += BK) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          // Load elements into threadgroup
          loader_a.load_safe(short2(BK, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, BK));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          // Multiply and accumulate threadgroup elements
          mma_op.mma(As, Bs);

          // Prepare for next iteration
          loader_a.next();
          loader_b.next();
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (k < K) {
          loader_a.load_safe(short2(K - k, src_tile_dims.y));
          loader_b.load_safe(short2(src_tile_dims.x, K - k));

          threadgroup_barrier(mem_flags::mem_threadgroup);

          mma_op.mma(As, Bs);
        }

        threadgroup_barrier(mem_flags::mem_none);
        mma_op.store_result_safe(C, N, src_tile_dims);

        return;
      }
    }
  }
};
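
Editor's note, not part of the commit: GEMMKernel above exposes tgp_size and tgp_mem_size, which a host dispatcher would use to size the launch. The sketch below reproduces that arithmetic in plain C++ for one configuration; BM, BN, BK, WM, WN and the problem sizes are illustrative assumptions, not values fixed by this header.

// Editor sketch: host-side dispatch arithmetic mirroring GEMMKernel above.
// Tile sizes are assumed for the example; tgp_padding = 16 / sizeof(T) as in the header.
#include <cstdio>

int main() {
  const int BM = 64, BN = 64, BK = 16, WM = 2, WN = 2; // assumed tile config
  const int M = 1000, N = 768, K = 512, batch = 4;     // example problem
  const bool transpose_a = false, transpose_b = false;

  int pad = 16 / sizeof(float);
  int tgp_mem_a = transpose_a ? BK * (BM + pad) : BM * (BK + pad);
  int tgp_mem_b = transpose_b ? BN * (BK + pad) : BK * (BN + pad);
  int tgp_mem_bytes = (tgp_mem_a + tgp_mem_b) * int(sizeof(float));

  int tgp_size = WM * WN * 32;    // threads per threadgroup
  int grid_x = (N + BN - 1) / BN; // tid.x selects the BN column block
  int grid_y = (M + BM - 1) / BM; // tid.y selects the BM row block
  std::printf("grid=(%d,%d,%d) threads/tg=%d tg memory=%d bytes\n",
              grid_x, grid_y, batch, tgp_size, tgp_mem_bytes);
  return 0;
}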
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h
ADDED
@@ -0,0 +1,176 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

union bool4_or_uint {
  bool4 b;
  unsigned int i;
};

struct None {
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_store_explicit(out, val, offset);
  }
};

struct And {
  bool simd_reduce(bool val) {
    return simd_all(val);
  };

  static constexpr constant bool init = true;

  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (!val) {
      bool4_or_uint update;
      update.b = {true, true, true, true};
      update.b[elem_idx] = false;
      mlx_atomic_fetch_and_explicit(out, update.i, offset);
    }
  }

  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (!val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out &= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a && b;
  }
};

struct Or {
  bool simd_reduce(bool val) {
    return simd_any(val);
  };

  static constexpr constant bool init = false;

  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (val) {
      bool4_or_uint update;
      update.b = {false, false, false, false};
      update.b[elem_idx] = true;
      mlx_atomic_fetch_or_explicit(out, update.i, offset);
    }
  }

  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out |= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a || b;
  }
};

template <typename U>
struct Sum {
  template <typename T>
  T simd_reduce(T val) {
    return simd_sum(val);
  };

  static constexpr constant U init = U(0);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_add_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a + b;
  }
};

template <typename U>
struct Prod {
  template <typename T>
  T simd_reduce(T val) {
    return simd_product(val);
  };

  static constexpr constant U init = U(1);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_mul_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a * b;
  }
};

template <typename U>
struct Min {
  template <typename T>
  T simd_reduce(T val) {
    return simd_min(val);
  };

  static constexpr constant U init = Limits<U>::max;

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_min_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a < b ? a : b;
  }
};

template <typename U>
struct Max {
  template <typename T>
  T simd_reduce(T val) {
    return simd_max(val);
  };

  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_max_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a > b ? a : b;
  }
|
| 176 |
+
};
|
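Each reduction op above packages an identity value (init), a binary operator, a simdgroup reduction, and an atomic update. A minimal host-side sketch of the same op-struct pattern in plain C++ (names are illustrative; this is not the Metal kernel code):

#include <cstdio>
#include <vector>

template <typename U>
struct SumOp {
  static constexpr U init = U(0);  // identity element
  U operator()(U a, U b) const { return a + b; }
};

template <typename U, typename Op>
U reduce(const std::vector<U>& x, Op op) {
  U acc = Op::init;
  for (const U& v : x) {
    acc = op(acc, v);  // fold with the binary operator
  }
  return acc;
}

int main() {
  std::vector<float> x{1.f, 2.f, 3.f, 4.f};
  std::printf("%f\n", reduce(x, SumOp<float>{}));  // prints 10.000000
  return 0;
}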
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h
ADDED
|
@@ -0,0 +1,246 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_math>
|
| 6 |
+
#include "mlx/backend/metal/kernels/bf16.h"
|
| 7 |
+
#include "mlx/backend/metal/kernels/complex.h"
|
| 8 |
+
|
| 9 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 10 |
+
// Type limits utils
|
| 11 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 12 |
+
|
| 13 |
+
template <typename U>
|
| 14 |
+
struct Limits {
|
| 15 |
+
static const constant U max;
|
| 16 |
+
static const constant U min;
|
| 17 |
+
static const constant U finite_max;
|
| 18 |
+
static const constant U finite_min;
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
#define instantiate_default_limit(type) \
|
| 22 |
+
template <> \
|
| 23 |
+
struct Limits<type> { \
|
| 24 |
+
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
| 25 |
+
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
| 26 |
+
static constexpr constant type finite_max = \
|
| 27 |
+
metal::numeric_limits<type>::max(); \
|
| 28 |
+
static constexpr constant type finite_min = \
|
| 29 |
+
metal::numeric_limits<type>::min(); \
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
instantiate_default_limit(uint8_t);
|
| 33 |
+
instantiate_default_limit(uint16_t);
|
| 34 |
+
instantiate_default_limit(uint32_t);
|
| 35 |
+
instantiate_default_limit(uint64_t);
|
| 36 |
+
instantiate_default_limit(int8_t);
|
| 37 |
+
instantiate_default_limit(int16_t);
|
| 38 |
+
instantiate_default_limit(int32_t);
|
| 39 |
+
instantiate_default_limit(int64_t);
|
| 40 |
+
|
| 41 |
+
#define instantiate_float_limit(type) \
|
| 42 |
+
template <> \
|
| 43 |
+
struct Limits<type> { \
|
| 44 |
+
static constexpr constant type max = \
|
| 45 |
+
metal::numeric_limits<type>::infinity(); \
|
| 46 |
+
static constexpr constant type min = \
|
| 47 |
+
-metal::numeric_limits<type>::infinity(); \
|
| 48 |
+
static constexpr constant type finite_max = \
|
| 49 |
+
metal::numeric_limits<type>::max(); \
|
| 50 |
+
static constexpr constant type finite_min = \
|
| 51 |
+
-metal::numeric_limits<type>::max(); \
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
instantiate_float_limit(half);
|
| 55 |
+
instantiate_float_limit(float);
|
| 56 |
+
instantiate_float_limit(bfloat16_t);
|
| 57 |
+
|
| 58 |
+
template <>
|
| 59 |
+
struct Limits<bool> {
|
| 60 |
+
static constexpr constant bool max = true;
|
| 61 |
+
static constexpr constant bool min = false;
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
// Indexing utils
|
| 66 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 67 |
+
|
| 68 |
+
inline size_t elem_to_loc(
|
| 69 |
+
uint elem,
|
| 70 |
+
device const int* shape,
|
| 71 |
+
device const size_t* strides,
|
| 72 |
+
int ndim) {
|
| 73 |
+
size_t loc = 0;
|
| 74 |
+
for (int i = ndim - 1; i >= 0; --i) {
|
| 75 |
+
loc += (elem % shape[i]) * strides[i];
|
| 76 |
+
elem /= shape[i];
|
| 77 |
+
}
|
| 78 |
+
return loc;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
inline size_t elem_to_loc(
|
| 82 |
+
uint elem,
|
| 83 |
+
constant const int* shape,
|
| 84 |
+
constant const size_t* strides,
|
| 85 |
+
int ndim) {
|
| 86 |
+
size_t loc = 0;
|
| 87 |
+
for (int i = ndim - 1; i >= 0; --i) {
|
| 88 |
+
loc += (elem % shape[i]) * strides[i];
|
| 89 |
+
elem /= shape[i];
|
| 90 |
+
}
|
| 91 |
+
return loc;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
template <int NDIM>
|
| 95 |
+
inline uint2 elem_to_loc_2_nd(
|
| 96 |
+
uint3 elem,
|
| 97 |
+
constant const int shape[NDIM],
|
| 98 |
+
constant const size_t a_strides[NDIM],
|
| 99 |
+
constant const size_t b_strides[NDIM]) {
|
| 100 |
+
uint2 loc = {
|
| 101 |
+
static_cast<uint>(
|
| 102 |
+
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
| 103 |
+
static_cast<uint>(
|
| 104 |
+
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
| 105 |
+
for (int d = NDIM - 3; d >= 0; --d) {
|
| 106 |
+
uint l = elem.z % shape[d];
|
| 107 |
+
loc.x += l * a_strides[d];
|
| 108 |
+
loc.y += l * b_strides[d];
|
| 109 |
+
elem.z /= shape[d];
|
| 110 |
+
}
|
| 111 |
+
return loc;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template <int NDIM>
|
| 115 |
+
inline size_t elem_to_loc_nd(
|
| 116 |
+
uint3 elem,
|
| 117 |
+
constant const int shape[NDIM],
|
| 118 |
+
constant const size_t strides[NDIM]) {
|
| 119 |
+
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
| 120 |
+
for (int d = NDIM - 3; d >= 0; --d) {
|
| 121 |
+
loc += (elem.z % shape[d]) * strides[d];
|
| 122 |
+
elem.z /= shape[d];
|
| 123 |
+
}
|
| 124 |
+
return loc;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
| 128 |
+
return elem * stride;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
| 132 |
+
return elem.x * strides[1] + elem.y * strides[0];
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
| 136 |
+
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
// Non templated version to handle arbitrary dims
|
| 140 |
+
inline size_t elem_to_loc(
|
| 141 |
+
uint3 elem,
|
| 142 |
+
constant const int* shape,
|
| 143 |
+
constant const size_t* strides,
|
| 144 |
+
int ndim) {
|
| 145 |
+
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
| 146 |
+
for (int d = ndim - 3; d >= 0; --d) {
|
| 147 |
+
loc += (elem.z % shape[d]) * strides[d];
|
| 148 |
+
elem.z /= shape[d];
|
| 149 |
+
}
|
| 150 |
+
return loc;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
inline uint2 elem_to_loc_2_nd(
|
| 154 |
+
uint3 elem,
|
| 155 |
+
constant const int* shape,
|
| 156 |
+
constant const size_t* a_strides,
|
| 157 |
+
constant const size_t* b_strides,
|
| 158 |
+
int ndim) {
|
| 159 |
+
uint2 loc = {
|
| 160 |
+
static_cast<uint>(
|
| 161 |
+
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
| 162 |
+
static_cast<uint>(
|
| 163 |
+
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
| 164 |
+
for (int d = ndim - 3; d >= 0; --d) {
|
| 165 |
+
uint l = elem.z % shape[d];
|
| 166 |
+
loc.x += l * a_strides[d];
|
| 167 |
+
loc.y += l * b_strides[d];
|
| 168 |
+
elem.z /= shape[d];
|
| 169 |
+
}
|
| 170 |
+
return loc;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
template <int NDIM>
|
| 174 |
+
inline uint elem_to_loc_nd(
|
| 175 |
+
uint elem,
|
| 176 |
+
device const int* shape,
|
| 177 |
+
device const size_t* strides);
|
| 178 |
+
|
| 179 |
+
template <>
|
| 180 |
+
inline uint elem_to_loc_nd<1>(
|
| 181 |
+
uint elem,
|
| 182 |
+
device const int* shape,
|
| 183 |
+
device const size_t* strides) {
|
| 184 |
+
return (elem % shape[0]) * strides[0];
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
template <>
|
| 188 |
+
inline uint elem_to_loc_nd<2>(
|
| 189 |
+
uint elem,
|
| 190 |
+
device const int* shape,
|
| 191 |
+
device const size_t* strides) {
|
| 192 |
+
uint loc = (elem % shape[1]) * strides[1];
|
| 193 |
+
elem /= shape[1];
|
| 194 |
+
loc += (elem % shape[0]) * strides[0];
|
| 195 |
+
return loc;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
template <>
|
| 199 |
+
inline uint elem_to_loc_nd<3>(
|
| 200 |
+
uint elem,
|
| 201 |
+
device const int* shape,
|
| 202 |
+
device const size_t* strides) {
|
| 203 |
+
uint loc = (elem % shape[2]) * strides[2];
|
| 204 |
+
elem /= shape[2];
|
| 205 |
+
loc += (elem % shape[1]) * strides[1];
|
| 206 |
+
elem /= shape[1];
|
| 207 |
+
loc += (elem % shape[0]) * strides[0];
|
| 208 |
+
return loc;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
template <>
|
| 212 |
+
inline uint elem_to_loc_nd<4>(
|
| 213 |
+
uint elem,
|
| 214 |
+
device const int* shape,
|
| 215 |
+
device const size_t* strides) {
|
| 216 |
+
uint loc = (elem % shape[3]) * strides[3];
|
| 217 |
+
elem /= shape[3];
|
| 218 |
+
loc += (elem % shape[2]) * strides[2];
|
| 219 |
+
elem /= shape[2];
|
| 220 |
+
loc += (elem % shape[1]) * strides[1];
|
| 221 |
+
elem /= shape[1];
|
| 222 |
+
loc += (elem % shape[0]) * strides[0];
|
| 223 |
+
return loc;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 227 |
+
// Calculation utils
|
| 228 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 229 |
+
|
| 230 |
+
/** Compute ceil((float)N/(float)M) */
|
| 231 |
+
inline size_t ceildiv(size_t N, size_t M) {
|
| 232 |
+
return (N + M - 1) / M;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
| 236 |
+
inline float log1p(float x) {
|
| 237 |
+
float xp1 = 1.0f + x;
|
| 238 |
+
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
inline bfloat16_t log1p(bfloat16_t x) {
|
| 242 |
+
float xp1 = 1.0f + static_cast<float>(x);
|
| 243 |
+
bfloat16_t ret =
|
| 244 |
+
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
| 245 |
+
return ret;
|
| 246 |
+
}
|
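elem_to_loc converts a flat element index into a storage offset by peeling off one coordinate per dimension with % and /, then accumulating coordinate * stride. A small plain-C++ sketch of the same arithmetic, using an illustrative transposed 2x3 layout:

#include <cstddef>
#include <cstdio>

size_t elem_to_loc(unsigned elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];  // coordinate along axis i times its stride
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // Logical shape {2, 3} viewed over transposed storage: strides {1, 2} (in elements).
  int shape[2] = {2, 3};
  size_t strides[2] = {1, 2};
  for (unsigned e = 0; e < 6; ++e) {
    std::printf("elem %u -> loc %zu\n", e, elem_to_loc(e, shape, strides, 2));
  }
  return 0;
}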
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h
ADDED
|
@@ -0,0 +1,31 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <cassert>
#include <sstream>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"

namespace mlx::core {

void mlx_matmul(
    const Stream& s,
    metal::Device& d,
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    int batch_size_out,
    int lda,
    int ldb,
    bool transpose_a,
    bool transpose_b,
    std::vector<array>& copies);

} // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h
ADDED
|
@@ -0,0 +1,31 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <future>
#include <memory>
#include <vector>

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::metal {

constexpr bool is_available() {
#ifdef _METAL_
  return true;
#else
  return false;
#endif
}

void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();

std::function<void()> make_task(
    array& arr,
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p,
    bool retain_graph);

} // namespace mlx::core::metal
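A short usage sketch combining the compile-time is_available() flag with the Device type declared later in this diff (device.h); it assumes the mlx headers are on the include path and the program links against libmlx:

#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"

namespace mx = mlx::core;

int main() {
  // Pick the GPU when the Metal backend was compiled in, otherwise fall back to CPU.
  mx::Device d = mx::metal::is_available() ? mx::Device(mx::Device::gpu)
                                           : mx::Device(mx::Device::cpu);
  mx::set_default_device(d);
  return 0;
}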
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h
ADDED
|
@@ -0,0 +1,370 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <Metal/Metal.hpp>
|
| 6 |
+
|
| 7 |
+
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
|
| 8 |
+
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
|
| 9 |
+
|
| 10 |
+
namespace MTL::Private::Class {
|
| 11 |
+
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
|
| 12 |
+
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
|
| 13 |
+
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
|
| 14 |
+
_MTL_PRIVATE_DEF_CLS(MPSVector);
|
| 15 |
+
_MTL_PRIVATE_DEF_CLS(MPSKernel);
|
| 16 |
+
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
|
| 17 |
+
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
|
| 18 |
+
} // namespace MTL::Private::Class
|
| 19 |
+
|
| 20 |
+
namespace MTL::Private::Selector {
|
| 21 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 22 |
+
matrixDescriptorWithRows_columns_rowBytes_dataType,
|
| 23 |
+
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
|
| 24 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 25 |
+
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
|
| 26 |
+
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
|
| 27 |
+
_MTL_PRIVATE_DEF_SEL(rows, "rows");
|
| 28 |
+
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
|
| 29 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 30 |
+
initWithDevice_,
|
| 31 |
+
"initWithDevice:transposeLeft:transposeRight:"
|
| 32 |
+
"resultRows:resultColumns:interiorColumns:alpha:beta:");
|
| 33 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 34 |
+
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
|
| 35 |
+
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
|
| 36 |
+
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
|
| 37 |
+
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
|
| 38 |
+
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
|
| 39 |
+
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
|
| 40 |
+
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
|
| 41 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 42 |
+
vectorDescriptorWithLength_dataType,
|
| 43 |
+
"vectorDescriptorWithLength:dataType:");
|
| 44 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 45 |
+
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
|
| 46 |
+
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
|
| 47 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 48 |
+
initWithDevice_transpose_rows_columns_alpha_beta,
|
| 49 |
+
"initWithDevice:transpose:rows:columns:alpha:beta:");
|
| 50 |
+
_MTL_PRIVATE_DEF_SEL(
|
| 51 |
+
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
|
| 52 |
+
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
|
| 53 |
+
} // namespace MTL::Private::Selector
|
| 54 |
+
|
| 55 |
+
namespace MPS {
|
| 56 |
+
|
| 57 |
+
typedef enum DataType : uint32_t {
|
| 58 |
+
DataTypeFloatBit = 0x10000000,
|
| 59 |
+
DataTypeAlternateEncodingBit = 0x80000000,
|
| 60 |
+
DataTypeFloat16 = DataTypeFloatBit | 16,
|
| 61 |
+
DataTypeFloat32 = DataTypeFloatBit | 32,
|
| 62 |
+
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
|
| 63 |
+
} DataType;
|
| 64 |
+
|
| 65 |
+
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
|
| 66 |
+
public:
|
| 67 |
+
static class MatrixDescriptor* matrixDescriptor(
|
| 68 |
+
NS::UInteger rows,
|
| 69 |
+
NS::UInteger columns,
|
| 70 |
+
NS::UInteger rowBytes,
|
| 71 |
+
NS::UInteger dataType);
|
| 72 |
+
static class MatrixDescriptor* matrixDescriptor(
|
| 73 |
+
NS::UInteger rows,
|
| 74 |
+
NS::UInteger columns,
|
| 75 |
+
NS::UInteger matrices,
|
| 76 |
+
NS::UInteger rowBytes,
|
| 77 |
+
NS::UInteger matrixBytes,
|
| 78 |
+
NS::UInteger dataType);
|
| 79 |
+
NS::UInteger rows() const;
|
| 80 |
+
};
|
| 81 |
+
|
| 82 |
+
class Matrix : public NS::Referencing<Matrix> {
|
| 83 |
+
public:
|
| 84 |
+
static class Matrix* alloc();
|
| 85 |
+
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
| 86 |
+
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
|
| 87 |
+
};
|
| 88 |
+
|
| 89 |
+
class Kernel : public NS::Referencing<Kernel> {
|
| 90 |
+
public:
|
| 91 |
+
NS::String* label() const;
|
| 92 |
+
MTL::Device* device() const;
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
class MatrixMultiplication
|
| 96 |
+
: public NS::Referencing<MatrixMultiplication, Kernel> {
|
| 97 |
+
public:
|
| 98 |
+
static class MatrixMultiplication* alloc();
|
| 99 |
+
|
| 100 |
+
MatrixMultiplication* init(
|
| 101 |
+
MTL::Device* device,
|
| 102 |
+
bool transposeLeft,
|
| 103 |
+
bool transposeRight,
|
| 104 |
+
NS::UInteger resultRows,
|
| 105 |
+
NS::UInteger resultColumns,
|
| 106 |
+
NS::UInteger interiorColumns,
|
| 107 |
+
double alpha,
|
| 108 |
+
double beta);
|
| 109 |
+
|
| 110 |
+
void encodeToCommandBuffer(
|
| 111 |
+
MTL::CommandBuffer* commandBuffer,
|
| 112 |
+
Matrix* leftMatrix,
|
| 113 |
+
Matrix* rightMatrix,
|
| 114 |
+
Matrix* resultMatrix);
|
| 115 |
+
|
| 116 |
+
void setLeftMatrixOrigin(MTL::Origin origin);
|
| 117 |
+
void setRightMatrixOrigin(MTL::Origin origin);
|
| 118 |
+
void setResultMatrixOrigin(MTL::Origin origin);
|
| 119 |
+
void setBatchStart(NS::UInteger batchStart);
|
| 120 |
+
void setBatchSize(NS::UInteger batchSize);
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
|
| 124 |
+
public:
|
| 125 |
+
static class VectorDescriptor* vectorDescriptor(
|
| 126 |
+
NS::UInteger length,
|
| 127 |
+
NS::UInteger dataType);
|
| 128 |
+
static class VectorDescriptor* vectorDescriptor(
|
| 129 |
+
NS::UInteger length,
|
| 130 |
+
NS::UInteger vectors,
|
| 131 |
+
NS::UInteger vectorBytes,
|
| 132 |
+
NS::UInteger dataType);
|
| 133 |
+
};
|
| 134 |
+
|
| 135 |
+
class Vector : public NS::Referencing<Vector> {
|
| 136 |
+
public:
|
| 137 |
+
static class Vector* alloc();
|
| 138 |
+
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
| 139 |
+
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
class MatrixVectorMultiplication
|
| 143 |
+
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
|
| 144 |
+
public:
|
| 145 |
+
static class MatrixVectorMultiplication* alloc();
|
| 146 |
+
|
| 147 |
+
MatrixVectorMultiplication* init(
|
| 148 |
+
MTL::Device* device,
|
| 149 |
+
bool transpose,
|
| 150 |
+
NS::UInteger rows,
|
| 151 |
+
NS::UInteger columns,
|
| 152 |
+
double alpha,
|
| 153 |
+
double beta);
|
| 154 |
+
|
| 155 |
+
void encodeToCommandBuffer(
|
| 156 |
+
MTL::CommandBuffer* commandBuffer,
|
| 157 |
+
Matrix* inputMatrix,
|
| 158 |
+
Vector* inputVector,
|
| 159 |
+
Vector* resultVector);
|
| 160 |
+
};
|
| 161 |
+
|
| 162 |
+
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
| 163 |
+
NS::UInteger rows,
|
| 164 |
+
NS::UInteger columns,
|
| 165 |
+
NS::UInteger rowBytes,
|
| 166 |
+
NS::UInteger dataType) {
|
| 167 |
+
return Object::sendMessage<MatrixDescriptor*>(
|
| 168 |
+
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
| 169 |
+
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
|
| 170 |
+
rows,
|
| 171 |
+
columns,
|
| 172 |
+
rowBytes,
|
| 173 |
+
dataType);
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
|
| 177 |
+
NS::UInteger rows,
|
| 178 |
+
NS::UInteger columns,
|
| 179 |
+
NS::UInteger matrices,
|
| 180 |
+
NS::UInteger rowBytes,
|
| 181 |
+
NS::UInteger matrixBytes,
|
| 182 |
+
NS::UInteger dataType) {
|
| 183 |
+
return Object::sendMessage<MatrixDescriptor*>(
|
| 184 |
+
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
|
| 185 |
+
_MPS_PRIVATE_SEL(
|
| 186 |
+
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
|
| 187 |
+
rows,
|
| 188 |
+
columns,
|
| 189 |
+
matrices,
|
| 190 |
+
rowBytes,
|
| 191 |
+
matrixBytes,
|
| 192 |
+
dataType);
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
|
| 196 |
+
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
_MTL_INLINE Matrix* Matrix::alloc() {
|
| 200 |
+
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
_MTL_INLINE Matrix* Matrix::init(
|
| 204 |
+
MTL::Buffer* buffer,
|
| 205 |
+
MatrixDescriptor* descriptor) {
|
| 206 |
+
return Object::sendMessage<Matrix*>(
|
| 207 |
+
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
_MTL_INLINE Matrix* Matrix::init(
|
| 211 |
+
const MTL::Buffer* buffer,
|
| 212 |
+
MatrixDescriptor* descriptor) {
|
| 213 |
+
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
_MTL_INLINE NS::String* Kernel::label() const {
|
| 217 |
+
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
_MTL_INLINE MTL::Device* Kernel::device() const {
|
| 221 |
+
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
|
| 225 |
+
return NS::Object::alloc<MatrixMultiplication>(
|
| 226 |
+
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
|
| 230 |
+
MTL::Device* device,
|
| 231 |
+
bool transposeLeft,
|
| 232 |
+
bool transposeRight,
|
| 233 |
+
NS::UInteger resultRows,
|
| 234 |
+
NS::UInteger resultColumns,
|
| 235 |
+
NS::UInteger interiorColumns,
|
| 236 |
+
double alpha,
|
| 237 |
+
double beta) {
|
| 238 |
+
return Object::sendMessage<MatrixMultiplication*>(
|
| 239 |
+
this,
|
| 240 |
+
_MPS_PRIVATE_SEL(initWithDevice_),
|
| 241 |
+
device,
|
| 242 |
+
transposeLeft,
|
| 243 |
+
transposeRight,
|
| 244 |
+
resultRows,
|
| 245 |
+
resultColumns,
|
| 246 |
+
interiorColumns,
|
| 247 |
+
alpha,
|
| 248 |
+
beta);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
|
| 252 |
+
MTL::CommandBuffer* commandBuffer,
|
| 253 |
+
Matrix* leftMatrix,
|
| 254 |
+
Matrix* rightMatrix,
|
| 255 |
+
Matrix* resultMatrix) {
|
| 256 |
+
return Object::sendMessage<void>(
|
| 257 |
+
this,
|
| 258 |
+
_MPS_PRIVATE_SEL(
|
| 259 |
+
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
|
| 260 |
+
commandBuffer,
|
| 261 |
+
leftMatrix,
|
| 262 |
+
rightMatrix,
|
| 263 |
+
resultMatrix);
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
|
| 267 |
+
Object::sendMessage<void>(
|
| 268 |
+
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
|
| 272 |
+
MTL::Origin origin) {
|
| 273 |
+
Object::sendMessage<void>(
|
| 274 |
+
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
|
| 278 |
+
MTL::Origin origin) {
|
| 279 |
+
Object::sendMessage<void>(
|
| 280 |
+
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
|
| 284 |
+
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
|
| 288 |
+
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
| 292 |
+
NS::UInteger length,
|
| 293 |
+
NS::UInteger dataType) {
|
| 294 |
+
return Object::sendMessage<VectorDescriptor*>(
|
| 295 |
+
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
| 296 |
+
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
|
| 297 |
+
length,
|
| 298 |
+
dataType);
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
|
| 302 |
+
NS::UInteger length,
|
| 303 |
+
NS::UInteger vectors,
|
| 304 |
+
NS::UInteger vectorBytes,
|
| 305 |
+
NS::UInteger dataType) {
|
| 306 |
+
return Object::sendMessage<VectorDescriptor*>(
|
| 307 |
+
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
|
| 308 |
+
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
|
| 309 |
+
length,
|
| 310 |
+
vectors,
|
| 311 |
+
vectorBytes,
|
| 312 |
+
dataType);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
_MTL_INLINE Vector* Vector::alloc() {
|
| 316 |
+
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
_MTL_INLINE Vector* Vector::init(
|
| 320 |
+
MTL::Buffer* buffer,
|
| 321 |
+
VectorDescriptor* descriptor) {
|
| 322 |
+
return Object::sendMessage<Vector*>(
|
| 323 |
+
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
_MTL_INLINE Vector* Vector::init(
|
| 327 |
+
const MTL::Buffer* buffer,
|
| 328 |
+
VectorDescriptor* descriptor) {
|
| 329 |
+
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
|
| 333 |
+
return NS::Object::alloc<MatrixVectorMultiplication>(
|
| 334 |
+
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
|
| 338 |
+
MTL::Device* device,
|
| 339 |
+
bool transpose,
|
| 340 |
+
NS::UInteger rows,
|
| 341 |
+
NS::UInteger columns,
|
| 342 |
+
double alpha,
|
| 343 |
+
double beta) {
|
| 344 |
+
return Object::sendMessage<MatrixVectorMultiplication*>(
|
| 345 |
+
this,
|
| 346 |
+
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
|
| 347 |
+
device,
|
| 348 |
+
transpose,
|
| 349 |
+
rows,
|
| 350 |
+
columns,
|
| 351 |
+
alpha,
|
| 352 |
+
beta);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
|
| 356 |
+
MTL::CommandBuffer* commandBuffer,
|
| 357 |
+
Matrix* inputMatrix,
|
| 358 |
+
Vector* inputVector,
|
| 359 |
+
Vector* resultVector) {
|
| 360 |
+
return Object::sendMessage<void>(
|
| 361 |
+
this,
|
| 362 |
+
_MPS_PRIVATE_SEL(
|
| 363 |
+
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
|
| 364 |
+
commandBuffer,
|
| 365 |
+
inputMatrix,
|
| 366 |
+
inputVector,
|
| 367 |
+
resultVector);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
} // namespace MPS
|
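The MatrixDescriptor factories above take explicit rowBytes and matrixBytes strides. A plain-C++ sketch of that size bookkeeping for a tightly packed, batched float buffer (values are illustrative):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t rows = 128, columns = 64, matrices = 4;
  const size_t rowBytes = columns * sizeof(float);   // stride between rows (packed)
  const size_t matrixBytes = rows * rowBytes;        // stride between matrices in the batch
  const size_t bufferBytes = matrices * matrixBytes; // total MTL::Buffer size needed
  std::printf("rowBytes=%zu matrixBytes=%zu bufferBytes=%zu\n",
              rowBytes, matrixBytes, bufferBytes);
  return 0;
}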
lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h
ADDED
|
@@ -0,0 +1,169 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "mlx/array.h"
|
| 6 |
+
#include "mlx/backend/metal/device.h"
|
| 7 |
+
|
| 8 |
+
namespace mlx::core {
|
| 9 |
+
|
| 10 |
+
namespace {
|
| 11 |
+
|
| 12 |
+
void set_array_buffer(
|
| 13 |
+
MTL::ComputeCommandEncoder* compute_encoder,
|
| 14 |
+
MTL::ArgumentEncoder* enc,
|
| 15 |
+
const array& a,
|
| 16 |
+
int idx) {
|
| 17 |
+
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
| 18 |
+
auto offset = a.data<char>() -
|
| 19 |
+
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
| 20 |
+
enc->setBuffer(a_buf, offset, idx);
|
| 21 |
+
// MTL::Resource usage through argument buffer needs to be explicitly
|
| 22 |
+
// flagged to enable hazard tracking
|
| 23 |
+
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
void set_array_buffer(
|
| 27 |
+
MTL::ComputeCommandEncoder* enc,
|
| 28 |
+
const array& a,
|
| 29 |
+
int idx) {
|
| 30 |
+
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
| 31 |
+
auto offset = a.data<char>() -
|
| 32 |
+
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
| 33 |
+
enc->setBuffer(a_buf, offset, idx);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
std::string type_to_name(const array& a) {
|
| 37 |
+
std::string tname;
|
| 38 |
+
switch (a.dtype()) {
|
| 39 |
+
case bool_:
|
| 40 |
+
tname = "bool_";
|
| 41 |
+
break;
|
| 42 |
+
case uint8:
|
| 43 |
+
tname = "uint8";
|
| 44 |
+
break;
|
| 45 |
+
case uint16:
|
| 46 |
+
tname = "uint16";
|
| 47 |
+
break;
|
| 48 |
+
case uint32:
|
| 49 |
+
tname = "uint32";
|
| 50 |
+
break;
|
| 51 |
+
case uint64:
|
| 52 |
+
tname = "uint64";
|
| 53 |
+
break;
|
| 54 |
+
case int8:
|
| 55 |
+
tname = "int8";
|
| 56 |
+
break;
|
| 57 |
+
case int16:
|
| 58 |
+
tname = "int16";
|
| 59 |
+
break;
|
| 60 |
+
case int32:
|
| 61 |
+
tname = "int32";
|
| 62 |
+
break;
|
| 63 |
+
case int64:
|
| 64 |
+
tname = "int64";
|
| 65 |
+
break;
|
| 66 |
+
case float16:
|
| 67 |
+
tname = "float16";
|
| 68 |
+
break;
|
| 69 |
+
case float32:
|
| 70 |
+
tname = "float32";
|
| 71 |
+
break;
|
| 72 |
+
case bfloat16:
|
| 73 |
+
tname = "bfloat16";
|
| 74 |
+
break;
|
| 75 |
+
case complex64:
|
| 76 |
+
tname = "complex64";
|
| 77 |
+
break;
|
| 78 |
+
}
|
| 79 |
+
return tname;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
| 83 |
+
int pows[3] = {0, 0, 0};
|
| 84 |
+
int sum = 0;
|
| 85 |
+
while (true) {
|
| 86 |
+
int presum = sum;
|
| 87 |
+
// Check all the pows
|
| 88 |
+
if (dim0 >= (1 << (pows[0] + 1))) {
|
| 89 |
+
pows[0]++;
|
| 90 |
+
sum++;
|
| 91 |
+
}
|
| 92 |
+
if (sum == 10) {
|
| 93 |
+
break;
|
| 94 |
+
}
|
| 95 |
+
if (dim1 >= (1 << (pows[1] + 1))) {
|
| 96 |
+
pows[1]++;
|
| 97 |
+
sum++;
|
| 98 |
+
}
|
| 99 |
+
if (sum == 10) {
|
| 100 |
+
break;
|
| 101 |
+
}
|
| 102 |
+
if (dim2 >= (1 << (pows[2] + 1))) {
|
| 103 |
+
pows[2]++;
|
| 104 |
+
sum++;
|
| 105 |
+
}
|
| 106 |
+
if (sum == presum || sum == 10) {
|
| 107 |
+
break;
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
// Collapse dims that are contiguous to possibly route to a better kernel
|
| 114 |
+
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
| 115 |
+
// should return {{2, 4}, {{1, 2}}}.
|
| 116 |
+
//
|
| 117 |
+
// When multiple arrays are passed they should all have the same shape. The
|
| 118 |
+
// collapsed axes are also the same so one shape is returned.
|
| 119 |
+
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
| 120 |
+
collapse_contiguous_dims(const std::vector<array>& xs) {
|
| 121 |
+
// Make a vector that has axes separated with -1. Collapse all axes between
|
| 122 |
+
// -1.
|
| 123 |
+
std::vector<int> to_collapse;
|
| 124 |
+
if (xs[0].ndim() > 0) {
|
| 125 |
+
to_collapse.push_back(0);
|
| 126 |
+
for (int i = 1; i < xs[0].ndim(); i++) {
|
| 127 |
+
bool contiguous = true;
|
| 128 |
+
for (auto& x : xs) {
|
| 129 |
+
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
|
| 130 |
+
contiguous = false;
|
| 131 |
+
}
|
| 132 |
+
if (!contiguous) {
|
| 133 |
+
break;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
if (!contiguous) {
|
| 137 |
+
to_collapse.push_back(-1);
|
| 138 |
+
}
|
| 139 |
+
to_collapse.push_back(i);
|
| 140 |
+
}
|
| 141 |
+
to_collapse.push_back(-1);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
std::vector<int> out_shape;
|
| 145 |
+
std::vector<std::vector<size_t>> out_strides(xs.size());
|
| 146 |
+
for (int i = 0; i < to_collapse.size(); i++) {
|
| 147 |
+
int current_shape = xs[0].shape()[to_collapse[i]];
|
| 148 |
+
while (to_collapse[++i] != -1) {
|
| 149 |
+
current_shape *= xs[0].shape()[to_collapse[i]];
|
| 150 |
+
}
|
| 151 |
+
out_shape.push_back(current_shape);
|
| 152 |
+
for (int j = 0; j < xs.size(); j++) {
|
| 153 |
+
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
return std::make_tuple(out_shape, out_strides);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
template <typename... Arrays>
|
| 161 |
+
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
| 162 |
+
collapse_contiguous_dims(Arrays... xs) {
|
| 163 |
+
return collapse_contiguous_dims(
|
| 164 |
+
std::vector<array>{std::forward<Arrays>(xs)...});
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
} // namespace
|
| 168 |
+
|
| 169 |
+
} // namespace mlx::core
|
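collapse_contiguous_dims merges adjacent axes whenever strides[i] * shape[i] equals the stride of the previous axis, as in the transpose example in the comment above. A standalone plain-C++ sketch of that rule applied to raw shape/stride vectors, reproducing the {2, 4} / {1, 2} result:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // transpose(reshape(arange(8), {2, 2, 2}), {2, 0, 1}):
  // shape {2, 2, 2}, strides {1, 4, 2} (in elements).
  std::vector<int> shape{2, 2, 2};
  std::vector<size_t> strides{1, 4, 2};

  std::vector<int> out_shape{shape[0]};
  std::vector<size_t> out_strides{strides[0]};
  for (size_t i = 1; i < shape.size(); ++i) {
    if (strides[i] * shape[i] == out_strides.back()) {
      out_shape.back() *= shape[i];      // contiguous with the previous axis: merge
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);     // gap in the layout: start a new axis
      out_strides.push_back(strides[i]);
    }
  }
  for (size_t i = 0; i < out_shape.size(); ++i) {
    std::printf("axis %zu: size %d, stride %zu\n", i, out_shape[i], out_strides[i]);
  }
  return 0;
}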
lib/python3.11/site-packages/mlx/include/mlx/device.h
ADDED
|
@@ -0,0 +1,29 @@
// Copyright © 2023 Apple Inc.

#pragma once

namespace mlx::core {

struct Device {
  enum class DeviceType {
    cpu,
    gpu,
  };

  static constexpr DeviceType cpu = DeviceType::cpu;
  static constexpr DeviceType gpu = DeviceType::gpu;

  Device(DeviceType type, int index = 0) : type(type), index(index){};

  DeviceType type;
  int index;
};

const Device& default_device();

void set_default_device(const Device& d);

bool operator==(const Device& lhs, const Device& rhs);
bool operator!=(const Device& lhs, const Device& rhs);

} // namespace mlx::core
lib/python3.11/site-packages/mlx/include/mlx/dtype.h
ADDED
|
@@ -0,0 +1,105 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <complex>
|
| 6 |
+
#include <cstdint>
|
| 7 |
+
#include <ostream>
|
| 8 |
+
#include <string>
|
| 9 |
+
|
| 10 |
+
#include "mlx/types/complex.h"
|
| 11 |
+
#include "mlx/types/half_types.h"
|
| 12 |
+
|
| 13 |
+
namespace mlx::core {
|
| 14 |
+
|
| 15 |
+
struct Dtype {
|
| 16 |
+
enum class Val {
|
| 17 |
+
bool_,
|
| 18 |
+
uint8,
|
| 19 |
+
uint16,
|
| 20 |
+
uint32,
|
| 21 |
+
uint64,
|
| 22 |
+
int8,
|
| 23 |
+
int16,
|
| 24 |
+
int32,
|
| 25 |
+
int64,
|
| 26 |
+
float16,
|
| 27 |
+
float32,
|
| 28 |
+
bfloat16,
|
| 29 |
+
complex64,
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
enum class Kind {
|
| 33 |
+
b, /* bool */
|
| 34 |
+
u, /* unsigned int */
|
| 35 |
+
i, /* signed int */
|
| 36 |
+
f, /* float */
|
| 37 |
+
c, /* complex */
|
| 38 |
+
V, /* void - used for brain float */
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
Val val;
|
| 42 |
+
const uint8_t size;
|
| 43 |
+
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
|
| 44 |
+
constexpr operator Val() const {
|
| 45 |
+
return val;
|
| 46 |
+
};
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
inline bool is_available(const Dtype& dtype) {
|
| 50 |
+
return true;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
|
| 54 |
+
|
| 55 |
+
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
|
| 56 |
+
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
|
| 57 |
+
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
|
| 58 |
+
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
|
| 59 |
+
|
| 60 |
+
static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
|
| 61 |
+
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
|
| 62 |
+
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
|
| 63 |
+
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
|
| 64 |
+
|
| 65 |
+
static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
|
| 66 |
+
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
|
| 67 |
+
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
|
| 68 |
+
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
|
| 69 |
+
|
| 70 |
+
Dtype promote_types(const Dtype& t1, const Dtype& t2);
|
| 71 |
+
|
| 72 |
+
inline uint8_t size_of(const Dtype& t) {
|
| 73 |
+
return t.size;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
Dtype::Kind kindof(const Dtype& t);
|
| 77 |
+
|
| 78 |
+
inline bool is_unsigned(const Dtype& t) {
|
| 79 |
+
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
inline bool is_floating_point(const Dtype& t) {
|
| 83 |
+
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
|
| 84 |
+
kindof(t) == Dtype::Kind::c;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
inline bool is_complex(const Dtype& t) {
|
| 88 |
+
return kindof(t) == Dtype::Kind::c;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
inline bool is_integral(const Dtype& t) {
|
| 92 |
+
return !(is_floating_point(t));
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
template <typename T>
|
| 96 |
+
struct TypeToDtype {
|
| 97 |
+
operator Dtype();
|
| 98 |
+
};
|
| 99 |
+
|
| 100 |
+
// Array protocol typestring for Dtype
|
| 101 |
+
std::string dtype_to_array_protocol(const Dtype& t);
|
| 102 |
+
// Dtype from array protocol type string
|
| 103 |
+
Dtype dtype_from_array_protocol(const std::string& t);
|
| 104 |
+
|
| 105 |
+
} // namespace mlx::core
|
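A small sketch of inspecting the constexpr Dtype descriptors declared above; it assumes the header is on the include path and, for the kindof-based predicates, that the program links against libmlx:

#include <cstdio>

#include "mlx/dtype.h"

namespace mx = mlx::core;

int main() {
  std::printf("float16: %d bytes, floating point: %d\n",
              (int)mx::size_of(mx::float16), (int)mx::is_floating_point(mx::float16));
  std::printf("uint32:  %d bytes, unsigned: %d\n",
              (int)mx::size_of(mx::uint32), (int)mx::is_unsigned(mx::uint32));
  return 0;
}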
lib/python3.11/site-packages/mlx/include/mlx/fft.h
ADDED
|
@@ -0,0 +1,151 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <variant>
|
| 6 |
+
|
| 7 |
+
#include "array.h"
|
| 8 |
+
#include "device.h"
|
| 9 |
+
#include "stream.h"
|
| 10 |
+
|
| 11 |
+
namespace mlx::core::fft {
|
| 12 |
+
|
| 13 |
+
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
| 14 |
+
|
| 15 |
+
/** Compute the n-dimensional Fourier Transform. */
|
| 16 |
+
array fftn(
|
| 17 |
+
const array& a,
|
| 18 |
+
const std::vector<int>& n,
|
| 19 |
+
const std::vector<int>& axes,
|
| 20 |
+
StreamOrDevice s = {});
|
| 21 |
+
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
|
| 22 |
+
array fftn(const array& a, StreamOrDevice s = {});
|
| 23 |
+
|
| 24 |
+
/** Compute the n-dimensional inverse Fourier Transform. */
|
| 25 |
+
array ifftn(
|
| 26 |
+
const array& a,
|
| 27 |
+
const std::vector<int>& n,
|
| 28 |
+
const std::vector<int>& axes,
|
| 29 |
+
StreamOrDevice s = {});
|
| 30 |
+
array ifftn(
|
| 31 |
+
const array& a,
|
| 32 |
+
const std::vector<int>& axes,
|
| 33 |
+
StreamOrDevice s = {});
|
| 34 |
+
array ifftn(const array& a, StreamOrDevice s = {});
|
| 35 |
+
|
| 36 |
+
/** Compute the one-dimensional Fourier Transform. */
|
| 37 |
+
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
| 38 |
+
return fftn(a, {n}, {axis}, s);
|
| 39 |
+
}
|
| 40 |
+
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
| 41 |
+
return fftn(a, {axis}, s);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
/** Compute the one-dimensional inverse Fourier Transform. */
|
| 45 |
+
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
| 46 |
+
return ifftn(a, {n}, {axis}, s);
|
| 47 |
+
}
|
| 48 |
+
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
| 49 |
+
return ifftn(a, {axis}, s);
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
/** Compute the two-dimensional Fourier Transform. */
|
| 53 |
+
inline array fft2(
|
| 54 |
+
const array& a,
|
| 55 |
+
const std::vector<int>& n,
|
| 56 |
+
const std::vector<int>& axes,
|
| 57 |
+
StreamOrDevice s = {}) {
|
| 58 |
+
return fftn(a, n, axes, s);
|
| 59 |
+
}
|
| 60 |
+
inline array fft2(
|
| 61 |
+
const array& a,
|
| 62 |
+
const std::vector<int>& axes = {-2, -1},
|
| 63 |
+
StreamOrDevice s = {}) {
|
| 64 |
+
return fftn(a, axes, s);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
/** Compute the two-dimensional inverse Fourier Transform. */
|
| 68 |
+
inline array ifft2(
|
| 69 |
+
const array& a,
|
| 70 |
+
const std::vector<int>& n,
|
| 71 |
+
const std::vector<int>& axes,
|
| 72 |
+
StreamOrDevice s = {}) {
|
| 73 |
+
return ifftn(a, n, axes, s);
|
| 74 |
+
}
|
| 75 |
+
inline array ifft2(
|
| 76 |
+
const array& a,
|
| 77 |
+
const std::vector<int>& axes = {-2, -1},
|
| 78 |
+
StreamOrDevice s = {}) {
|
| 79 |
+
return ifftn(a, axes, s);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/** Compute the n-dimensional Fourier Transform on a real input. */
|
| 83 |
+
array rfftn(
|
| 84 |
+
const array& a,
|
| 85 |
+
const std::vector<int>& n,
|
| 86 |
+
const std::vector<int>& axes,
|
| 87 |
+
StreamOrDevice s = {});
|
| 88 |
+
array rfftn(
|
| 89 |
+
const array& a,
|
| 90 |
+
const std::vector<int>& axes,
|
| 91 |
+
StreamOrDevice s = {});
|
| 92 |
+
array rfftn(const array& a, StreamOrDevice s = {});
|
| 93 |
+
|
| 94 |
+
/** Compute the n-dimensional inverse of `rfftn`. */
|
| 95 |
+
array irfftn(
|
| 96 |
+
const array& a,
|
| 97 |
+
const std::vector<int>& n,
|
| 98 |
+
const std::vector<int>& axes,
|
| 99 |
+
StreamOrDevice s = {});
|
| 100 |
+
array irfftn(
|
| 101 |
+
const array& a,
|
| 102 |
+
const std::vector<int>& axes,
|
| 103 |
+
StreamOrDevice s = {});
|
| 104 |
+
array irfftn(const array& a, StreamOrDevice s = {});
|
| 105 |
+
|
| 106 |
+
/** Compute the one-dimensional Fourier Transform on a real input. */
|
| 107 |
+
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
| 108 |
+
return rfftn(a, {n}, {axis}, s);
|
| 109 |
+
}
|
| 110 |
+
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
| 111 |
+
return rfftn(a, {axis}, s);
|
| 112 |
+
}
|
| 113 |
+
/** Compute the one-dimensional inverse of `rfft`. */
|
| 114 |
+
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
|
| 115 |
+
return irfftn(a, {n}, {axis}, s);
|
| 116 |
+
}
|
| 117 |
+
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
|
| 118 |
+
return irfftn(a, {axis}, s);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/** Compute the two-dimensional Fourier Transform on a real input. */
|
| 122 |
+
inline array rfft2(
|
| 123 |
+
const array& a,
|
| 124 |
+
const std::vector<int>& n,
|
| 125 |
+
const std::vector<int>& axes,
|
| 126 |
+
StreamOrDevice s = {}) {
|
| 127 |
+
return rfftn(a, n, axes, s);
|
| 128 |
+
}
|
| 129 |
+
inline array rfft2(
|
| 130 |
+
const array& a,
|
| 131 |
+
const std::vector<int>& axes = {-2, -1},
|
| 132 |
+
StreamOrDevice s = {}) {
|
| 133 |
+
return rfftn(a, axes, s);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/** Compute the two-dimensional inverse of `rfft2`. */
|
| 137 |
+
inline array irfft2(
|
| 138 |
+
const array& a,
|
| 139 |
+
const std::vector<int>& n,
|
| 140 |
+
const std::vector<int>& axes,
|
| 141 |
+
StreamOrDevice s = {}) {
|
| 142 |
+
return irfftn(a, n, axes, s);
|
| 143 |
+
}
|
| 144 |
+
inline array irfft2(
|
| 145 |
+
const array& a,
|
| 146 |
+
const std::vector<int>& axes = {-2, -1},
|
| 147 |
+
StreamOrDevice s = {}) {
|
| 148 |
+
return irfftn(a, axes, s);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
} // namespace mlx::core::fft
|
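These wrappers all reduce to fftn/ifftn and their real-input variants. Under the standard one-sided (NumPy-style) convention that rfft is assumed to follow here, a real signal of length n produces n/2 + 1 complex bins, which is why irfft takes n to recover the original length; a tiny plain-C++ sketch of that bookkeeping:

#include <cstdio>
#include <initializer_list>

int main() {
  for (int n : {8, 9, 16}) {
    int bins = n / 2 + 1;  // one-sided spectrum length for a real input of length n
    std::printf("n=%d -> rfft bins=%d\n", n, bins);
  }
  return 0;
}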
lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h
ADDED
|
@@ -0,0 +1,23 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void print_graph(std::ostream& os, const std::vector<array>& outputs);

template <typename... Arrays>
void print_graph(std::ostream& os, Arrays... outputs) {
  print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}

void export_to_dot(std::ostream& os, const std::vector<array>& outputs);

template <typename... Arrays>
void export_to_dot(std::ostream& os, Arrays... outputs) {
  export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}

} // namespace mlx::core
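print_graph and export_to_dot use a variadic overload that packs any number of outputs into a vector and forwards to the single vector-based implementation. A self-contained plain-C++ sketch of that forwarding pattern, with strings standing in for arrays:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

void describe(const std::vector<std::string>& outputs) {
  std::printf("%zu outputs\n", outputs.size());
}

template <typename... Args>
void describe(Args... outputs) {
  // Pack the arguments into a vector and call the non-template overload.
  describe(std::vector<std::string>{std::forward<Args>(outputs)...});
}

int main() {
  describe(std::string("a"), std::string("b"), std::string("c"));  // prints "3 outputs"
  return 0;
}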
lib/python3.11/site-packages/mlx/include/mlx/io/load.h
ADDED
|
@@ -0,0 +1,114 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <fstream>
#include <istream>
#include <memory>

namespace mlx::core {

namespace io {

class Reader {
 public:
  virtual bool is_open() const = 0;
  virtual bool good() const = 0;
  virtual size_t tell() const = 0;
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  virtual void read(char* data, size_t n) = 0;
  virtual std::string label() const = 0;
};

class Writer {
 public:
  virtual bool is_open() const = 0;
  virtual bool good() const = 0;
  virtual size_t tell() const = 0;
  virtual void seek(
      int64_t off,
      std::ios_base::seekdir way = std::ios_base::beg) = 0;
  virtual void write(const char* data, size_t n) = 0;
  virtual std::string label() const = 0;
};

class FileReader : public Reader {
 public:
  explicit FileReader(const std::shared_ptr<std::ifstream>& is)
      : is_(is), label_("stream") {}
  explicit FileReader(const std::string& file_path)
      : is_(std::make_shared<std::ifstream>(file_path, std::ios::binary)),
        label_(file_path) {}

  bool is_open() const override {
    return is_->is_open();
  }

  bool good() const override {
    return is_->good();
  }

  size_t tell() const override {
    return is_->tellg();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    is_->seekg(off, way);
  }

  void read(char* data, size_t n) override {
    is_->read(data, n);
  }

  std::string label() const override {
    return "file " + label_;
  }

 private:
  std::shared_ptr<std::ifstream> is_;
  std::string label_;
};

class FileWriter : public Writer {
 public:
  explicit FileWriter(const std::shared_ptr<std::ofstream>& is)
      : os_(is), label_("stream") {}
  explicit FileWriter(const std::string& file_path)
      : os_(std::make_shared<std::ofstream>(file_path, std::ios::binary)),
        label_(file_path) {}

  bool is_open() const override {
    return os_->is_open();
  }

  bool good() const override {
    return os_->good();
  }

  size_t tell() const override {
    return os_->tellp();
  }

  void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
      override {
    os_->seekp(off, way);
  }

  void write(const char* data, size_t n) override {
    os_->write(data, n);
  }

  std::string label() const override {
    return "file " + label_;
  }

 private:
  std::shared_ptr<std::ofstream> os_;
  std::string label_;
};

} // namespace io
} // namespace mlx::core
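Editor's note: a minimal round-trip sketch for the FileReader/FileWriter classes declared above, assuming this header is available as "mlx/io/load.h" on the include path; the temporary file path is illustrative only.

// Write a few bytes through the Writer interface, then read them back.
#include <cassert>
#include <string>

#include "mlx/io/load.h"

int main() {
  using namespace mlx::core::io;
  const std::string path = "/tmp/mlx_io_example.bin";  // illustrative path
  const std::string payload = "hello";

  {
    // The writer goes out of scope (closing the underlying ofstream)
    // before the file is read back.
    FileWriter writer(path);
    assert(writer.is_open() && writer.good());
    writer.write(payload.data(), payload.size());
  }

  FileReader reader(path);
  assert(reader.is_open());
  char buffer[5] = {};
  reader.read(buffer, sizeof(buffer));
  assert(std::string(buffer, sizeof(buffer)) == payload);
  return 0;
}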
lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h ADDED
@@ -0,0 +1,32 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <json.hpp>

#include "mlx/io/load.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"

using json = nlohmann::json;

namespace mlx::core {

#define ST_F16 "F16"
#define ST_BF16 "BF16"
#define ST_F32 "F32"

#define ST_BOOL "BOOL"
#define ST_I8 "I8"
#define ST_I16 "I16"
#define ST_I32 "I32"
#define ST_I64 "I64"
#define ST_U8 "U8"
#define ST_U16 "U16"
#define ST_U32 "U32"
#define ST_U64 "U64"

// Note: Complex numbers aren't in the spec yet so this could change -
// https://github.com/huggingface/safetensors/issues/389
#define ST_C64 "C64"
} // namespace mlx::core
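Editor's note: the ST_* macros above are the dtype strings used in a .safetensors header. The helper below is not part of the library; it is an illustrative sketch of how those strings might be mapped to element sizes when parsing such a header (the function name and the byte widths for each dtype are the editor's assumptions, with C64 taken as two 32-bit floats).

#include <cstddef>
#include <stdexcept>
#include <string>

#include "mlx/io/safetensor.h"

// Hypothetical helper: bytes per element for a safetensors dtype string.
size_t st_itemsize(const std::string& dtype) {
  if (dtype == ST_BOOL || dtype == ST_I8 || dtype == ST_U8)
    return 1;
  if (dtype == ST_F16 || dtype == ST_BF16 || dtype == ST_I16 || dtype == ST_U16)
    return 2;
  if (dtype == ST_F32 || dtype == ST_I32 || dtype == ST_U32)
    return 4;
  if (dtype == ST_I64 || dtype == ST_U64 || dtype == ST_C64)
    return 8;
  throw std::invalid_argument("unknown safetensors dtype: " + dtype);
}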
lib/python3.11/site-packages/mlx/include/mlx/linalg.h ADDED
@@ -0,0 +1,63 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/stream.h"

namespace mlx::core::linalg {

/**
 * Compute vector or matrix norms.
 *
 * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
 * - If axis is not provided but ord is, then x must be either 1D or 2D.
 * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
 *   for matrices) is computed along the given axes. At most 2 axes can be
 *   specified.
 * - If both axis and ord are provided, then the corresponding matrix or vector
 *   norm is computed. At most 2 axes can be specified.
 */
array norm(
    const array& a,
    const double ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
inline array norm(
    const array& a,
    const double ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
    const array& a,
    const std::string& ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
inline array norm(
    const array& a,
    const std::string& ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
    const array& a,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
  return norm(a, std::vector<int>{axis}, keepdims, s);
}

} // namespace mlx::core::linalg
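Editor's note: a minimal sketch exercising the norm() overloads declared above. It assumes the umbrella header mlx/mlx.h (listed further below) is on the include path and the library is linked; ones() and arange() come from mlx/ops.h, and eval() comes from mlx/transforms.h, which is part of this package but not reproduced in this diff.

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;

  array m = ones({3, 4});    // 3x4 matrix of ones
  array v = arange(5.0);     // vector [0, 1, 2, 3, 4]

  array fro = linalg::norm(m);           // Frobenius norm of the matrix
  array rows = linalg::norm(m, 2.0, 1);  // 2-norm of each row (axis overload)
  array l1 = linalg::norm(v, 1.0);       // 1-norm of the vector

  // Force evaluation of the lazy graph (eval() assumed from mlx/transforms.h).
  eval({fro, rows, l1});
  return 0;
}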
lib/python3.11/site-packages/mlx/include/mlx/mlx.h ADDED
@@ -0,0 +1,14 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/device.h"
#include "mlx/fft.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"
#include "mlx/utils.h"
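Editor's note: mlx/mlx.h is the umbrella header that pulls in the individual headers listed in this commit. A minimal sketch, assuming the library is linked; full(), ones(), matmul() and the operator overloads are declared in mlx/ops.h below, while eval() is assumed from mlx/transforms.h (not reproduced in this diff).

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;

  array a = full({2, 3}, 2.0f);  // 2x3 array filled with 2.0
  array b = ones({3, 2});        // 3x2 array of ones
  array c = matmul(a, b) + 1.0f; // lazy graph: matmul followed by an add
  eval({c});                     // evaluate on the default stream/device
  return 0;
}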
lib/python3.11/site-packages/mlx/include/mlx/ops.h ADDED
@@ -0,0 +1,1094 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>
#include <variant>

#include "array.h"
#include "device.h"
#include "io/load.h"
#include "stream.h"

namespace mlx::core {

using StreamOrDevice = std::variant<std::monostate, Stream, Device>;

Stream to_stream(StreamOrDevice s);

/** Creation operations */

/**
 * A 1D array of numbers starting at `start` (optional),
 * stopping at `stop`, stepping by `step` (optional). */
array arange(
    double start,
    double stop,
    double step,
    Dtype dtype,
    StreamOrDevice s = {});
array arange(double start, double stop, double step, StreamOrDevice s = {});
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double start, double stop, StreamOrDevice s = {});
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double stop, StreamOrDevice s = {});

array arange(int start, int stop, int step, StreamOrDevice s = {});
array arange(int start, int stop, StreamOrDevice s = {});
array arange(int stop, StreamOrDevice s = {});

/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
array linspace(
    double start,
    double stop,
    int num = 50,
    Dtype dtype = float32,
    StreamOrDevice s = {});

/** Convert an array to the given data type. */
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});

/** Create a view of an array with the given shape and strides. */
array as_strided(
    const array& a,
    std::vector<int> shape,
    std::vector<size_t> strides,
    size_t offset,
    StreamOrDevice s = {});

/** Copy another array. */
array copy(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape with the given value(s). */
array full(
    const std::vector<int>& shape,
    const array& vals,
    Dtype dtype,
    StreamOrDevice s = {});
array full(
    const std::vector<int>& shape,
    const array& vals,
    StreamOrDevice s = {});
template <typename T>
array full(
    const std::vector<int>& shape,
    T val,
    Dtype dtype,
    StreamOrDevice s = {}) {
  return full(shape, array(val, dtype), to_stream(s));
}
template <typename T>
array full(const std::vector<int>& shape, T val, StreamOrDevice s = {}) {
  return full(shape, array(val), to_stream(s));
}

/** Fill an array of the given shape with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
  return zeros(shape, float32, s);
}
array zeros_like(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
  return ones(shape, float32, s);
}
array ones_like(const array& a, StreamOrDevice s = {});

/** Fill an array of the given shape (n,m) with ones in the specified diagonal
 * k, and zeros everywhere else. */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
  return eye(n, n, 0, dtype, s);
}
inline array eye(int n, int m, StreamOrDevice s = {}) {
  return eye(n, m, 0, float32, s);
}
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
  return eye(n, m, k, float32, s);
}
inline array eye(int n, StreamOrDevice s = {}) {
  return eye(n, n, 0, float32, s);
}

/** Create a square matrix of shape (n,n) of zeros, and ones in the major
 * diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
inline array identity(int n, StreamOrDevice s = {}) {
  return identity(n, float32, s);
}

array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
  return tri(n, n, 0, type, s);
}

array tril(array x, int k, StreamOrDevice s = {});
array triu(array x, int k, StreamOrDevice s = {});

/** array manipulation */

/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});

/** Flatten the dimensions in the range `[start_axis, end_axis]`. */
array flatten(
    const array& a,
    int start_axis,
    int end_axis = -1,
    StreamOrDevice s = {});

/** Flatten the array to 1D. */
array flatten(const array& a, StreamOrDevice s = {});

/** Remove singleton dimensions at the given axes. */
array squeeze(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});

/** Remove singleton dimensions at the given axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
  return squeeze(a, std::vector<int>{axis}, s);
}

/** Remove all singleton dimensions. */
array squeeze(const array& a, StreamOrDevice s = {});

/** Add a singleton dimension at the given axes. */
array expand_dims(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});

/** Add a singleton dimension at the given axis. */
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
  return expand_dims(a, std::vector<int>{axis}, s);
}

/** Slice an array. */
array slice(
    const array& a,
    std::vector<int> start,
    std::vector<int> stop,
    std::vector<int> strides,
    StreamOrDevice s = {});

/** Slice an array with a stride of 1 in each dimension. */
array slice(
    const array& a,
    const std::vector<int>& start,
    const std::vector<int>& stop,
    StreamOrDevice s = {});

/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
    const array& a,
    const std::vector<int>& indices,
    int axis,
    StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});

/**
 * Clip (limit) the values in an array.
 */
array clip(
    const array& a,
    const std::optional<array>& a_min = std::nullopt,
    const std::optional<array>& a_max = std::nullopt,
    StreamOrDevice s = {});

/** Concatenate arrays along a given axis. */
array concatenate(
    const std::vector<array>& arrays,
    int axis,
    StreamOrDevice s = {});
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});

/** Stack arrays along a new axis. */
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});

/** Repeat an array along an axis. */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, StreamOrDevice s = {});

/** Permutes the dimensions according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
inline array transpose(
    const array& a,
    std::initializer_list<int> axes,
    StreamOrDevice s = {}) {
  return transpose(a, std::vector<int>(axes), s);
}

/** Swap two axes of an array. */
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});

/** Move an axis of an array. */
array moveaxis(
    const array& a,
    int source,
    int destination,
    StreamOrDevice s = {});

/** Pad an array with a constant value */
array pad(
    const array& a,
    const std::vector<int>& axes,
    const std::vector<int>& low_pad_size,
    const std::vector<int>& high_pad_size,
    const array& pad_value = array(0),
    StreamOrDevice s = {});

/** Pad an array with a constant value along all axes */
array pad(
    const array& a,
    const std::vector<std::pair<int, int>>& pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});
array pad(
    const array& a,
    const std::pair<int, int>& pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});
array pad(
    const array& a,
    int pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});

/** Permutes the dimensions in reverse order. */
array transpose(const array& a, StreamOrDevice s = {});

/** Broadcast an array to a given shape. */
array broadcast_to(
    const array& a,
    const std::vector<int>& shape,
    StreamOrDevice s = {});

/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
    const std::vector<array>& inputs,
    StreamOrDevice s = {});

/** Comparison operations */

/** Returns the bool array with (a == b) element-wise. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
  return equal(a, b);
}
template <typename T>
array operator==(T a, const array& b) {
  return equal(array(a), b);
}
template <typename T>
array operator==(const array& a, T b) {
  return equal(a, array(b));
}

/** Returns the bool array with (a != b) element-wise. */
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator!=(const array& a, const array& b) {
  return not_equal(a, b);
}
template <typename T>
array operator!=(T a, const array& b) {
  return not_equal(array(a), b);
}
template <typename T>
array operator!=(const array& a, T b) {
  return not_equal(a, array(b));
}

/** Returns bool array with (a > b) element-wise. */
array greater(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>(const array& a, const array& b) {
  return greater(a, b);
}
template <typename T>
array operator>(T a, const array& b) {
  return greater(array(a), b);
}
template <typename T>
array operator>(const array& a, T b) {
  return greater(a, array(b));
}

/** Returns bool array with (a >= b) element-wise. */
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>=(const array& a, const array& b) {
  return greater_equal(a, b);
}
template <typename T>
array operator>=(T a, const array& b) {
  return greater_equal(array(a), b);
}
template <typename T>
array operator>=(const array& a, T b) {
  return greater_equal(a, array(b));
}

/** Returns bool array with (a < b) element-wise. */
array less(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<(const array& a, const array& b) {
  return less(a, b);
}
template <typename T>
array operator<(T a, const array& b) {
  return less(array(a), b);
}
template <typename T>
array operator<(const array& a, T b) {
  return less(a, array(b));
}

/** Returns bool array with (a <= b) element-wise. */
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<=(const array& a, const array& b) {
  return less_equal(a, b);
}
template <typename T>
array operator<=(T a, const array& b) {
  return less_equal(array(a), b);
}
template <typename T>
array operator<=(const array& a, T b) {
  return less_equal(a, array(b));
}

/** True if two arrays have the same shape and elements. */
array array_equal(
    const array& a,
    const array& b,
    bool equal_nan,
    StreamOrDevice s = {});
inline array
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
  return array_equal(a, b, false, s);
}

/** Select from x or y depending on condition. */
array where(
    const array& condition,
    const array& x,
    const array& y,
    StreamOrDevice s = {});

/** Reduction operations */

/** True if all elements in the array are true (or non-zero). **/
array all(const array& a, bool keepdims, StreamOrDevice s = {});
inline array all(const array& a, StreamOrDevice s = {}) {
  return all(a, false, to_stream(s));
}

/** True if the two arrays are equal within the specified tolerance. */
array allclose(
    const array& a,
    const array& b,
    double rtol = 1e-5,
    double atol = 1e-8,
    StreamOrDevice s = {});

/**
 * Reduces the input along the given axes. An output value is true
 * if all the corresponding inputs are true.
 **/
array all(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Reduces the input along the given axis. An output value is true
 * if all the corresponding inputs are true.
 **/
array all(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** True if any elements in the array are true (or non-zero). **/
array any(const array& a, bool keepdims, StreamOrDevice s = {});
inline array any(const array& a, StreamOrDevice s = {}) {
  return any(a, false, to_stream(s));
}

/**
 * Reduces the input along the given axes. An output value is true
 * if any of the corresponding inputs are true.
 **/
array any(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Reduces the input along the given axis. An output value is true
 * if any of the corresponding inputs are true.
 **/
array any(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Sums the elements of an array. */
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
inline array sum(const array& a, StreamOrDevice s = {}) {
  return sum(a, false, to_stream(s));
}

/** Sums the elements of an array along the given axes. */
array sum(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Sums the elements of an array along the given axis. */
array sum(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Computes the mean of the elements of an array. */
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
inline array mean(const array& a, StreamOrDevice s = {}) {
  return mean(a, false, to_stream(s));
}

/** Computes the mean of the elements of an array along the given axes */
array mean(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Computes the mean of the elements of an array along the given axis */
array mean(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Computes the variance of the elements of an array. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array var(const array& a, StreamOrDevice s = {}) {
  return var(a, false, 0, to_stream(s));
}

/** Computes the var of the elements of an array along the given axes */
array var(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** Computes the var of the elements of an array along the given axis */
array var(
    const array& a,
    int axis,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** The product of all elements of the array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
inline array prod(const array& a, StreamOrDevice s = {}) {
  return prod(a, false, to_stream(s));
}

/** The product of the elements of an array along the given axes. */
array prod(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The product of the elements of an array along the given axis. */
array prod(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The maximum of all elements of the array. */
array max(const array& a, bool keepdims, StreamOrDevice s = {});
inline array max(const array& a, StreamOrDevice s = {}) {
  return max(a, false, to_stream(s));
}

/** The maximum of the elements of an array along the given axes. */
array max(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The maximum of the elements of an array along the given axis. */
array max(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The minimum of all elements of the array. */
array min(const array& a, bool keepdims, StreamOrDevice s = {});
inline array min(const array& a, StreamOrDevice s = {}) {
  return min(a, false, to_stream(s));
}

/** The minimum of the elements of an array along the given axes. */
array min(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The minimum of the elements of an array along the given axis. */
array min(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Returns the index of the minimum value in the array. */
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmin(const array& a, StreamOrDevice s = {}) {
  return argmin(a, false, s);
}

/** Returns the indices of the minimum values along a given axis. */
array argmin(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Returns the index of the maximum value in the array. */
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
inline array argmax(const array& a, StreamOrDevice s = {}) {
  return argmax(a, false, s);
}

/** Returns the indices of the maximum values along a given axis. */
array argmax(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Returns a sorted copy of the flattened array. */
array sort(const array& a, StreamOrDevice s = {});

/** Returns a sorted copy of the array along a given axis. */
array sort(const array& a, int axis, StreamOrDevice s = {});

/** Returns indices that sort the flattened array. */
array argsort(const array& a, StreamOrDevice s = {});

/** Returns indices that sort the array along a given axis. */
array argsort(const array& a, int axis, StreamOrDevice s = {});

/**
 * Returns a partitioned copy of the flattened array
 * such that the smaller kth elements are first.
 **/
array partition(const array& a, int kth, StreamOrDevice s = {});

/**
 * Returns a partitioned copy of the array along a given axis
 * such that the smaller kth elements are first.
 **/
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});

/**
 * Returns indices that partition the flattened array
 * such that the smaller kth elements are first.
 **/
array argpartition(const array& a, int kth, StreamOrDevice s = {});

/**
 * Returns indices that partition the array along a given axis
 * such that the smaller kth elements are first.
 **/
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});

/** Returns topk elements of the flattened array. */
array topk(const array& a, int k, StreamOrDevice s = {});

/** Returns topk elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});

/** The logsumexp of all elements of the array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
  return logsumexp(a, false, to_stream(s));
}

/** The logsumexp of the elements of an array along the given axes. */
array logsumexp(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** The logsumexp of the elements of an array along the given axis. */
array logsumexp(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Simple arithmetic operations */

/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});

/** Negate an array. */
array negative(const array& a, StreamOrDevice s = {});
array operator-(const array& a);

/** The sign of the elements in an array. */
array sign(const array& a, StreamOrDevice s = {});

/** Logical not of an array */
array logical_not(const array& a, StreamOrDevice s = {});

/** The reciprocal (1/x) of the elements in an array. */
array reciprocal(const array& a, StreamOrDevice s = {});

/** Add two arrays. */
array add(const array& a, const array& b, StreamOrDevice s = {});
array operator+(const array& a, const array& b);
template <typename T>
array operator+(T a, const array& b) {
  return add(array(a), b);
}
template <typename T>
array operator+(const array& a, T b) {
  return add(a, array(b));
}

/** Subtract two arrays. */
array subtract(const array& a, const array& b, StreamOrDevice s = {});
array operator-(const array& a, const array& b);
template <typename T>
array operator-(T a, const array& b) {
  return subtract(array(a), b);
}
template <typename T>
array operator-(const array& a, T b) {
  return subtract(a, array(b));
}

/** Multiply two arrays. */
array multiply(const array& a, const array& b, StreamOrDevice s = {});
array operator*(const array& a, const array& b);
template <typename T>
array operator*(T a, const array& b) {
  return multiply(array(a), b);
}
template <typename T>
array operator*(const array& a, T b) {
  return multiply(a, array(b));
}

/** Divide two arrays. */
array divide(const array& a, const array& b, StreamOrDevice s = {});
array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);

/** Compute integer division. Equivalent to doing floor(a / b). */
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});

/** Compute the element-wise remainder of division */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);
template <typename T>
array operator%(T a, const array& b) {
  return remainder(array(a), b);
}
template <typename T>
array operator%(const array& a, T b) {
  return remainder(a, array(b));
}

/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});

/** Element-wise minimum between two arrays. */
array minimum(const array& a, const array& b, StreamOrDevice s = {});

/** Floor the elements of an array. **/
array floor(const array& a, StreamOrDevice s = {});

/** Ceil the elements of an array. **/
array ceil(const array& a, StreamOrDevice s = {});

/** Square the elements of an array. */
array square(const array& a, StreamOrDevice s = {});

/** Exponential of the elements of an array. */
array exp(const array& a, StreamOrDevice s = {});

/** Sine of the elements of an array */
array sin(const array& a, StreamOrDevice s = {});

/** Cosine of the elements of an array */
array cos(const array& a, StreamOrDevice s = {});

/** Tangent of the elements of an array */
array tan(const array& a, StreamOrDevice s = {});

/** Arc Sine of the elements of an array */
array arcsin(const array& a, StreamOrDevice s = {});

/** Arc Cosine of the elements of an array */
array arccos(const array& a, StreamOrDevice s = {});

/** Arc Tangent of the elements of an array */
array arctan(const array& a, StreamOrDevice s = {});

/** Hyperbolic Sine of the elements of an array */
array sinh(const array& a, StreamOrDevice s = {});

/** Hyperbolic Cosine of the elements of an array */
array cosh(const array& a, StreamOrDevice s = {});

/** Hyperbolic Tangent of the elements of an array */
array tanh(const array& a, StreamOrDevice s = {});

/** Inverse Hyperbolic Sine of the elements of an array */
array arcsinh(const array& a, StreamOrDevice s = {});

/** Inverse Hyperbolic Cosine of the elements of an array */
array arccosh(const array& a, StreamOrDevice s = {});

/** Inverse Hyperbolic Tangent of the elements of an array */
array arctanh(const array& a, StreamOrDevice s = {});

/** Natural logarithm of the elements of an array. */
array log(const array& a, StreamOrDevice s = {});

/** Log base 2 of the elements of an array. */
array log2(const array& a, StreamOrDevice s = {});

/** Log base 10 of the elements of an array. */
array log10(const array& a, StreamOrDevice s = {});

/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
array log1p(const array& a, StreamOrDevice s = {});

/** Log-add-exp of the elements of the arrays: `log(exp(a) + exp(b))`. */
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});

/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x))`. */
array sigmoid(const array& a, StreamOrDevice s = {});

/** Computes the error function of the elements of an array. */
array erf(const array& a, StreamOrDevice s = {});

/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});

/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});

/** Round a floating point number */
array round(const array& a, int decimals, StreamOrDevice s = {});
inline array round(const array& a, StreamOrDevice s = {}) {
  return round(a, 0, s);
}

/** Matrix-matrix multiplication. */
array matmul(const array& a, const array& b, StreamOrDevice s = {});

/** Gather array entries given indices and slices */
array gather(
    const array& a,
    const std::vector<array>& indices,
    const std::vector<int>& axes,
    const std::vector<int>& slice_sizes,
    StreamOrDevice s = {});
inline array gather(
    const array& a,
    const array& indices,
    int axis,
    const std::vector<int>& slice_sizes,
    StreamOrDevice s = {}) {
  return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
}

/** Take array slices at the given indices of the specified axis. */
array take(
    const array& a,
    const array& indices,
    int axis,
    StreamOrDevice s = {});

/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});

/** Take array entries given indices along the axis */
array take_along_axis(
    const array& a,
    const array& indices,
    int axis,
    StreamOrDevice s = {});

/** Scatter updates to given linear indices */
array scatter(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
inline array scatter(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
}

/** Scatter and add updates to given indices */
array scatter_add(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
inline array scatter_add(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
}

/** Scatter and prod updates to given indices */
array scatter_prod(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
inline array scatter_prod(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
}

/** Scatter and max updates to given linear indices */
array scatter_max(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
inline array scatter_max(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
}
/** Scatter and min updates to given linear indices */
array scatter_min(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
inline array scatter_min(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
}

/** Square root the elements of an array. */
array sqrt(const array& a, StreamOrDevice s = {});

/** Square root and reciprocal the elements of an array. */
array rsqrt(const array& a, StreamOrDevice s = {});

/** Softmax of an array. */
array softmax(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});

/** Softmax of an array. */
array softmax(const array& a, StreamOrDevice s = {});

/** Softmax of an array. */
inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
  return softmax(a, std::vector<int>{axis}, s);
}

/** Raise elements of a to the power of b element-wise */
array power(const array& a, const array& b, StreamOrDevice s = {});
inline array operator^(const array& a, const array& b) {
  return power(a, b);
}
template <typename T>
array operator^(T a, const array& b) {
  return power(array(a), b);
}
template <typename T>
array operator^(const array& a, T b) {
  return power(a, array(b));
}

/** Cumulative sum of an array. */
array cumsum(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/** Cumulative product of an array. */
array cumprod(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/** Cumulative max of an array. */
array cummax(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/** Cumulative min of an array. */
array cummin(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/** Convolution operations */

/** 1D convolution with a filter */
array conv1d(
    const array& input,
    const array& weight,
    int stride = 1,
    int padding = 0,
    int dilation = 1,
    int groups = 1,
    StreamOrDevice s = {});

/** 2D convolution with a filter */
array conv2d(
    const array& input,
    const array& weight,
    const std::pair<int, int>& stride = {1, 1},
    const std::pair<int, int>& padding = {0, 0},
    const std::pair<int, int>& dilation = {1, 1},
    int groups = 1,
    StreamOrDevice s = {});

/** Serialization operations */

/** Save array to out stream in .npy format */
void save(
    std::shared_ptr<io::Writer> out_stream,
    array a,
    bool retain_graph = true);

/** Save array to file in .npy format */
void save(const std::string& file, array a, bool retain_graph = true);

/** Load array from reader in .npy format */
array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});

/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});

/** Quantized matmul multiplies x with a quantized matrix w. */
array quantized_matmul(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
    const array& w,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/** Dequantize a matrix produced by quantize() */
array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/** TensorDot returns a contraction of a and b over multiple dimensions. */
array tensordot(
    const array& a,
    const array& b,
    const int dims = 2,
    StreamOrDevice s = {});

array tensordot(
    const array& a,
    const array& b,
    const std::pair<std::vector<int>, std::vector<int>>& dims,
    StreamOrDevice s = {});

/** Load array map from .safetensors file format */
std::unordered_map<std::string, array> load_safetensors(
    std::shared_ptr<io::Reader> in_stream,
    StreamOrDevice s = {});
std::unordered_map<std::string, array> load_safetensors(
    const std::string& file,
    StreamOrDevice s = {});

void save_safetensors(
    std::shared_ptr<io::Writer> in_stream,
    std::unordered_map<std::string, array>,
    std::optional<bool> retain_graph = std::nullopt);
void save_safetensors(
    const std::string& file,
    std::unordered_map<std::string, array>,
    std::optional<bool> retain_graph = std::nullopt);
} // namespace mlx::core
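Editor's note: a short usage sketch for a few of the operations declared in ops.h above (reshape, reductions, quantization, and the safetensors round trip). The file name "weights.safetensors" is illustrative, the requirement that the last axis be a multiple of the group size is the editor's assumption, and eval() is assumed from mlx/transforms.h, which is part of this package but not reproduced in this diff.

#include <string>
#include <unordered_map>

#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;

  // An 8x64 float matrix; the last axis matches the default group_size of 64.
  array w = reshape(arange(512.0), {8, 64});

  // 4-bit quantization and reconstruction.
  auto [wq, scales, biases] = quantize(w, /* group_size = */ 64, /* bits = */ 4);
  array w_hat = dequantize(wq, scales, biases, 64, 4);
  array close = allclose(w, w_hat, /* rtol = */ 1e-2, /* atol = */ 1e-1);

  // Reductions over the reconstructed matrix.
  array per_row = sum(w_hat, /* axis = */ 1);
  array total = sum(w_hat);

  // Save and reload through the safetensors format.
  save_safetensors("weights.safetensors", {{"w", w_hat}});
  auto loaded = load_safetensors("weights.safetensors");
  array reloaded = loaded.at("w");

  eval({close, per_row, total, reloaded});
  return 0;
}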
lib/python3.11/site-packages/mlx/include/mlx/primitives.h ADDED
@@ -0,0 +1,1636 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "array.h"
#include "device.h"
#include "io/load.h"
#include "stream.h"

#define DEFINE_GRADS()                             \
  array jvp(                                       \
      const std::vector<array>& primals,           \
      const std::vector<array>& tangents,          \
      const std::vector<int>& argnums) override;   \
                                                   \
  std::vector<array> vjp(                          \
      const std::vector<array>& primals,           \
      const array& cotan,                          \
      const std::vector<int>& argnums) override;

#define DEFINE_PRINT(PRIMITIVE)           \
  void print(std::ostream& os) override { \
    os << #PRIMITIVE;                     \
  }

#define DEFINE_DEFAULT_IS_EQUIVALENT()                        \
  bool is_equivalent(const Primitive& other) const override { \
    return true;                                              \
  }

namespace mlx::core {

// Abstract base class
class Primitive {
 public:
  explicit Primitive(Stream stream) : stream_(stream) {}

  /** The device the primitive will run on. */
  const Device& device() {
    return stream().device;
  }

  /** The stream the primitive will run on. */
  const Stream& stream() {
    return stream_;
  }

  /**
   * A primitive must know how to evaluate itself on
   * the CPU/GPU for the given inputs and populate the output array.
   *
   * To avoid unnecessary allocations, the evaluation function
   * is responsible for allocating space for the array.
   */
  virtual void eval_cpu(const std::vector<array>& inputs, array& out) = 0;
  virtual void eval_gpu(const std::vector<array>& inputs, array& out) = 0;

  /**
   * The Jacobian-vector product.
   */
  virtual array jvp(
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums);

  /**
   * The vector-Jacobian product.
   */
  virtual std::vector<array> vjp(
      const std::vector<array>& primals,
      const array& cotan,
      const std::vector<int>& argnums);

  /**
   * The primitive must know how to vectorize itself across
   * the given axes. The output is a pair containing the array
   * representing the vectorized computation and the axis which
   * corresponds to the output vectorized dimension.
   */
  virtual std::pair<array, int> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes);

  /** Print the primitive. */
  virtual void print(std::ostream& os) = 0;

  /** Equivalence check defaults to false unless overridden by the primitive */
  virtual bool is_equivalent(const Primitive& other) const {
    return false;
  }

  virtual ~Primitive() = default;
  Primitive(const Primitive& other) = delete;
  Primitive(Primitive&& other) = delete;
  Primitive& operator=(const Primitive& other) = delete;
  Primitive& operator=(Primitive&& other) = delete;

 private:
  // Every primitive stores the stream it should run in
  Stream stream_;
};
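The jvp, vjp, and vmap members declared through DEFINE_GRADS() and vmap() are defined out of line in the library's backend sources, which are not part of this header diff. As a hedged illustration only (a sketch of the contract, not the definitions shipped in this package), this is roughly what an elementwise primitive such as the Cos class declared later in this file could look like; it assumes the ops cos, sin, multiply, and negative from mlx/ops.h.

```cpp
// Illustrative sketch only -- not the implementation shipped with this package.
#include "mlx/ops.h"
#include "mlx/primitives.h"

namespace mlx::core {

array Cos::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  // Pushforward of cos: tangent * -sin(x), evaluated on this primitive's stream.
  return multiply(
      tangents[0], negative(sin(primals[0], stream()), stream()), stream());
}

std::vector<array> Cos::vjp(
    const std::vector<array>& primals,
    const array& cotan,
    const std::vector<int>& argnums) {
  // Pullback of cos: cotangent * -sin(x).
  return {multiply(
      cotan, negative(sin(primals[0], stream()), stream()), stream())};
}

std::pair<array, int> Cos::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Elementwise op: apply cos to the batched input and keep the batch axis.
  return {cos(inputs[0], stream()), axes[0]};
}

} // namespace mlx::core
```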
|
| 102 |
+
|
| 103 |
+
class Abs : public Primitive {
|
| 104 |
+
public:
|
| 105 |
+
explicit Abs(Stream stream) : Primitive(stream){};
|
| 106 |
+
|
| 107 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 108 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 109 |
+
|
| 110 |
+
std::pair<array, int> vmap(
|
| 111 |
+
const std::vector<array>& inputs,
|
| 112 |
+
const std::vector<int>& axes) override;
|
| 113 |
+
|
| 114 |
+
DEFINE_GRADS()
|
| 115 |
+
DEFINE_PRINT(Abs)
|
| 116 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 117 |
+
|
| 118 |
+
private:
|
| 119 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 120 |
+
};
|
| 121 |
+
|
| 122 |
+
class Add : public Primitive {
|
| 123 |
+
public:
|
| 124 |
+
explicit Add(Stream stream) : Primitive(stream){};
|
| 125 |
+
|
| 126 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 127 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 128 |
+
|
| 129 |
+
std::pair<array, int> vmap(
|
| 130 |
+
const std::vector<array>& inputs,
|
| 131 |
+
const std::vector<int>& axes) override;
|
| 132 |
+
|
| 133 |
+
DEFINE_GRADS()
|
| 134 |
+
DEFINE_PRINT(Add)
|
| 135 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 136 |
+
|
| 137 |
+
private:
|
| 138 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 139 |
+
};
|
| 140 |
+
|
| 141 |
+
class Arange : public Primitive {
|
| 142 |
+
public:
|
| 143 |
+
explicit Arange(Stream stream, double start, double stop, double step)
|
| 144 |
+
: Primitive(stream), start_(start), stop_(stop), step_(step){};
|
| 145 |
+
|
| 146 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 147 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 148 |
+
|
| 149 |
+
DEFINE_PRINT(Arange)
|
| 150 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 151 |
+
|
| 152 |
+
private:
|
| 153 |
+
double start_;
|
| 154 |
+
double stop_;
|
| 155 |
+
double step_;
|
| 156 |
+
|
| 157 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 158 |
+
};
|
| 159 |
+
|
| 160 |
+
class ArcCos : public Primitive {
|
| 161 |
+
public:
|
| 162 |
+
explicit ArcCos(Stream stream) : Primitive(stream){};
|
| 163 |
+
|
| 164 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 165 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 166 |
+
|
| 167 |
+
std::pair<array, int> vmap(
|
| 168 |
+
const std::vector<array>& inputs,
|
| 169 |
+
const std::vector<int>& axes) override;
|
| 170 |
+
|
| 171 |
+
DEFINE_GRADS()
|
| 172 |
+
DEFINE_PRINT(ArcCos)
|
| 173 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 174 |
+
|
| 175 |
+
private:
|
| 176 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 177 |
+
};
|
| 178 |
+
|
| 179 |
+
class ArcCosh : public Primitive {
|
| 180 |
+
public:
|
| 181 |
+
explicit ArcCosh(Stream stream) : Primitive(stream){};
|
| 182 |
+
|
| 183 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 184 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 185 |
+
|
| 186 |
+
std::pair<array, int> vmap(
|
| 187 |
+
const std::vector<array>& inputs,
|
| 188 |
+
const std::vector<int>& axes) override;
|
| 189 |
+
|
| 190 |
+
DEFINE_GRADS()
|
| 191 |
+
DEFINE_PRINT(ArcCosh)
|
| 192 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 193 |
+
|
| 194 |
+
private:
|
| 195 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 196 |
+
};
|
| 197 |
+
|
| 198 |
+
class ArcSin : public Primitive {
|
| 199 |
+
public:
|
| 200 |
+
explicit ArcSin(Stream stream) : Primitive(stream){};
|
| 201 |
+
|
| 202 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 203 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 204 |
+
|
| 205 |
+
std::pair<array, int> vmap(
|
| 206 |
+
const std::vector<array>& inputs,
|
| 207 |
+
const std::vector<int>& axes) override;
|
| 208 |
+
|
| 209 |
+
DEFINE_GRADS()
|
| 210 |
+
DEFINE_PRINT(ArcSin)
|
| 211 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 212 |
+
|
| 213 |
+
private:
|
| 214 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
class ArcSinh : public Primitive {
|
| 218 |
+
public:
|
| 219 |
+
explicit ArcSinh(Stream stream) : Primitive(stream){};
|
| 220 |
+
|
| 221 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 222 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 223 |
+
|
| 224 |
+
std::pair<array, int> vmap(
|
| 225 |
+
const std::vector<array>& inputs,
|
| 226 |
+
const std::vector<int>& axes) override;
|
| 227 |
+
|
| 228 |
+
DEFINE_GRADS()
|
| 229 |
+
DEFINE_PRINT(ArcSinh)
|
| 230 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 231 |
+
|
| 232 |
+
private:
|
| 233 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 234 |
+
};
|
| 235 |
+
|
| 236 |
+
class ArcTan : public Primitive {
|
| 237 |
+
public:
|
| 238 |
+
explicit ArcTan(Stream stream) : Primitive(stream){};
|
| 239 |
+
|
| 240 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 241 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 242 |
+
|
| 243 |
+
std::pair<array, int> vmap(
|
| 244 |
+
const std::vector<array>& inputs,
|
| 245 |
+
const std::vector<int>& axes) override;
|
| 246 |
+
|
| 247 |
+
DEFINE_GRADS()
|
| 248 |
+
DEFINE_PRINT(ArcTan)
|
| 249 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 250 |
+
|
| 251 |
+
private:
|
| 252 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
class ArcTanh : public Primitive {
|
| 256 |
+
public:
|
| 257 |
+
explicit ArcTanh(Stream stream) : Primitive(stream){};
|
| 258 |
+
|
| 259 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 260 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 261 |
+
|
| 262 |
+
std::pair<array, int> vmap(
|
| 263 |
+
const std::vector<array>& inputs,
|
| 264 |
+
const std::vector<int>& axes) override;
|
| 265 |
+
|
| 266 |
+
DEFINE_GRADS()
|
| 267 |
+
DEFINE_PRINT(ArcTanh)
|
| 268 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 269 |
+
|
| 270 |
+
private:
|
| 271 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 272 |
+
};
|
| 273 |
+
|
| 274 |
+
class ArgPartition : public Primitive {
|
| 275 |
+
public:
|
| 276 |
+
explicit ArgPartition(Stream stream, int kth, int axis)
|
| 277 |
+
: Primitive(stream), kth_(kth), axis_(axis){};
|
| 278 |
+
|
| 279 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 280 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 281 |
+
|
| 282 |
+
std::pair<array, int> vmap(
|
| 283 |
+
const std::vector<array>& inputs,
|
| 284 |
+
const std::vector<int>& axes) override;
|
| 285 |
+
|
| 286 |
+
DEFINE_PRINT(ArgPartition)
|
| 287 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 288 |
+
|
| 289 |
+
private:
|
| 290 |
+
int kth_;
|
| 291 |
+
int axis_;
|
| 292 |
+
|
| 293 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 294 |
+
};
|
| 295 |
+
|
| 296 |
+
class ArgReduce : public Primitive {
|
| 297 |
+
public:
|
| 298 |
+
enum ReduceType {
|
| 299 |
+
ArgMin,
|
| 300 |
+
ArgMax,
|
| 301 |
+
};
|
| 302 |
+
|
| 303 |
+
explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
|
| 304 |
+
: Primitive(stream), reduce_type_(reduce_type), axis_(axis){};
|
| 305 |
+
|
| 306 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 307 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 308 |
+
|
| 309 |
+
DEFINE_PRINT(ArgReduce)
|
| 310 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 311 |
+
|
| 312 |
+
private:
|
| 313 |
+
ReduceType reduce_type_;
|
| 314 |
+
int axis_;
|
| 315 |
+
|
| 316 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 317 |
+
};
|
| 318 |
+
|
| 319 |
+
class ArgSort : public Primitive {
|
| 320 |
+
public:
|
| 321 |
+
explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){};
|
| 322 |
+
|
| 323 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 324 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 325 |
+
|
| 326 |
+
std::pair<array, int> vmap(
|
| 327 |
+
const std::vector<array>& inputs,
|
| 328 |
+
const std::vector<int>& axes) override;
|
| 329 |
+
|
| 330 |
+
DEFINE_PRINT(ArgSort)
|
| 331 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 332 |
+
|
| 333 |
+
private:
|
| 334 |
+
int axis_;
|
| 335 |
+
|
| 336 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 337 |
+
};
|
| 338 |
+
|
| 339 |
+
class AsType : public Primitive {
|
| 340 |
+
public:
|
| 341 |
+
explicit AsType(Stream stream, Dtype dtype)
|
| 342 |
+
: Primitive(stream), dtype_(dtype){};
|
| 343 |
+
|
| 344 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 345 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 346 |
+
|
| 347 |
+
std::pair<array, int> vmap(
|
| 348 |
+
const std::vector<array>& inputs,
|
| 349 |
+
const std::vector<int>& axes) override;
|
| 350 |
+
|
| 351 |
+
DEFINE_GRADS()
|
| 352 |
+
DEFINE_PRINT(AsType)
|
| 353 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 354 |
+
|
| 355 |
+
private:
|
| 356 |
+
Dtype dtype_;
|
| 357 |
+
|
| 358 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 359 |
+
};
|
| 360 |
+
|
| 361 |
+
class AsStrided : public Primitive {
|
| 362 |
+
public:
|
| 363 |
+
explicit AsStrided(
|
| 364 |
+
Stream stream,
|
| 365 |
+
const std::vector<int>& shape,
|
| 366 |
+
const std::vector<size_t>& strides,
|
| 367 |
+
size_t offset)
|
| 368 |
+
: Primitive(stream), shape_(shape), strides_(strides), offset_(offset){};
|
| 369 |
+
|
| 370 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 371 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 372 |
+
|
| 373 |
+
DEFINE_GRADS()
|
| 374 |
+
DEFINE_PRINT(AsStrided)
|
| 375 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 376 |
+
|
| 377 |
+
private:
|
| 378 |
+
std::vector<int> shape_;
|
| 379 |
+
std::vector<size_t> strides_;
|
| 380 |
+
size_t offset_;
|
| 381 |
+
|
| 382 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 383 |
+
};
|
| 384 |
+
|
| 385 |
+
class Broadcast : public Primitive {
|
| 386 |
+
public:
|
| 387 |
+
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
| 388 |
+
: Primitive(stream), shape_(shape){};
|
| 389 |
+
|
| 390 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 391 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 392 |
+
|
| 393 |
+
std::pair<array, int> vmap(
|
| 394 |
+
const std::vector<array>& inputs,
|
| 395 |
+
const std::vector<int>& axes) override;
|
| 396 |
+
|
| 397 |
+
DEFINE_GRADS()
|
| 398 |
+
DEFINE_PRINT(Broadcast)
|
| 399 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 400 |
+
|
| 401 |
+
private:
|
| 402 |
+
std::vector<int> shape_;
|
| 403 |
+
|
| 404 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 405 |
+
};
|
| 406 |
+
|
| 407 |
+
class Ceil : public Primitive {
|
| 408 |
+
public:
|
| 409 |
+
explicit Ceil(Stream stream) : Primitive(stream){};
|
| 410 |
+
|
| 411 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 412 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 413 |
+
|
| 414 |
+
std::pair<array, int> vmap(
|
| 415 |
+
const std::vector<array>& inputs,
|
| 416 |
+
const std::vector<int>& axes) override;
|
| 417 |
+
|
| 418 |
+
DEFINE_GRADS()
|
| 419 |
+
DEFINE_PRINT(Ceil)
|
| 420 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 421 |
+
|
| 422 |
+
private:
|
| 423 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 424 |
+
};
|
| 425 |
+
|
| 426 |
+
class Concatenate : public Primitive {
|
| 427 |
+
public:
|
| 428 |
+
explicit Concatenate(Stream stream, int axis)
|
| 429 |
+
: Primitive(stream), axis_(axis){};
|
| 430 |
+
|
| 431 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 432 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 433 |
+
|
| 434 |
+
std::pair<array, int> vmap(
|
| 435 |
+
const std::vector<array>& inputs,
|
| 436 |
+
const std::vector<int>& axes) override;
|
| 437 |
+
|
| 438 |
+
DEFINE_GRADS()
|
| 439 |
+
DEFINE_PRINT(Concatenate)
|
| 440 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 441 |
+
|
| 442 |
+
private:
|
| 443 |
+
int axis_;
|
| 444 |
+
|
| 445 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 446 |
+
};
|
| 447 |
+
|
| 448 |
+
class Convolution : public Primitive {
|
| 449 |
+
public:
|
| 450 |
+
explicit Convolution(
|
| 451 |
+
Stream stream,
|
| 452 |
+
const std::vector<int>& padding,
|
| 453 |
+
const std::vector<int>& kernel_strides,
|
| 454 |
+
const std::vector<int>& kernel_dilation,
|
| 455 |
+
const std::vector<int>& input_dilation)
|
| 456 |
+
: Primitive(stream),
|
| 457 |
+
padding_(padding),
|
| 458 |
+
kernel_strides_(kernel_strides),
|
| 459 |
+
kernel_dilation_(kernel_dilation),
|
| 460 |
+
input_dilation_(input_dilation){};
|
| 461 |
+
|
| 462 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 463 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 464 |
+
|
| 465 |
+
std::vector<array> vjp(
|
| 466 |
+
const std::vector<array>& primals,
|
| 467 |
+
const array& cotan,
|
| 468 |
+
const std::vector<int>& argnums) override;
|
| 469 |
+
|
| 470 |
+
DEFINE_PRINT(Convolution)
|
| 471 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 472 |
+
|
| 473 |
+
private:
|
| 474 |
+
std::vector<int> padding_;
|
| 475 |
+
std::vector<int> kernel_strides_;
|
| 476 |
+
std::vector<int> kernel_dilation_;
|
| 477 |
+
std::vector<int> input_dilation_;
|
| 478 |
+
|
| 479 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 480 |
+
};
|
| 481 |
+
|
| 482 |
+
class Copy : public Primitive {
|
| 483 |
+
public:
|
| 484 |
+
explicit Copy(Stream stream) : Primitive(stream){};
|
| 485 |
+
|
| 486 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 487 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 488 |
+
|
| 489 |
+
std::pair<array, int> vmap(
|
| 490 |
+
const std::vector<array>& inputs,
|
| 491 |
+
const std::vector<int>& axes) override;
|
| 492 |
+
|
| 493 |
+
DEFINE_GRADS()
|
| 494 |
+
DEFINE_PRINT(Copy)
|
| 495 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 496 |
+
|
| 497 |
+
private:
|
| 498 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 499 |
+
};
|
| 500 |
+
|
| 501 |
+
class Cos : public Primitive {
|
| 502 |
+
public:
|
| 503 |
+
explicit Cos(Stream stream) : Primitive(stream){};
|
| 504 |
+
|
| 505 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 506 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 507 |
+
|
| 508 |
+
std::pair<array, int> vmap(
|
| 509 |
+
const std::vector<array>& inputs,
|
| 510 |
+
const std::vector<int>& axes) override;
|
| 511 |
+
|
| 512 |
+
DEFINE_GRADS()
|
| 513 |
+
DEFINE_PRINT(Cos)
|
| 514 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 515 |
+
|
| 516 |
+
private:
|
| 517 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 518 |
+
};
|
| 519 |
+
|
| 520 |
+
class Cosh : public Primitive {
|
| 521 |
+
public:
|
| 522 |
+
explicit Cosh(Stream stream) : Primitive(stream){};
|
| 523 |
+
|
| 524 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 525 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 526 |
+
|
| 527 |
+
std::pair<array, int> vmap(
|
| 528 |
+
const std::vector<array>& inputs,
|
| 529 |
+
const std::vector<int>& axes) override;
|
| 530 |
+
|
| 531 |
+
DEFINE_GRADS()
|
| 532 |
+
DEFINE_PRINT(Cosh)
|
| 533 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 534 |
+
|
| 535 |
+
private:
|
| 536 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 537 |
+
};
|
| 538 |
+
|
| 539 |
+
class Divide : public Primitive {
|
| 540 |
+
public:
|
| 541 |
+
explicit Divide(Stream stream) : Primitive(stream){};
|
| 542 |
+
|
| 543 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 544 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 545 |
+
|
| 546 |
+
std::pair<array, int> vmap(
|
| 547 |
+
const std::vector<array>& inputs,
|
| 548 |
+
const std::vector<int>& axes) override;
|
| 549 |
+
|
| 550 |
+
DEFINE_GRADS()
|
| 551 |
+
DEFINE_PRINT(Divide)
|
| 552 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 553 |
+
|
| 554 |
+
private:
|
| 555 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 556 |
+
};
|
| 557 |
+
|
| 558 |
+
class Remainder : public Primitive {
|
| 559 |
+
public:
|
| 560 |
+
explicit Remainder(Stream stream) : Primitive(stream){};
|
| 561 |
+
|
| 562 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 563 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 564 |
+
|
| 565 |
+
std::pair<array, int> vmap(
|
| 566 |
+
const std::vector<array>& inputs,
|
| 567 |
+
const std::vector<int>& axes) override;
|
| 568 |
+
|
| 569 |
+
DEFINE_GRADS()
|
| 570 |
+
DEFINE_PRINT(Remainder)
|
| 571 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 572 |
+
|
| 573 |
+
private:
|
| 574 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 575 |
+
};
|
| 576 |
+
|
| 577 |
+
class Equal : public Primitive {
|
| 578 |
+
public:
|
| 579 |
+
explicit Equal(Stream stream, bool equal_nan = false)
|
| 580 |
+
: Primitive(stream), equal_nan_(equal_nan){};
|
| 581 |
+
|
| 582 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 583 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 584 |
+
|
| 585 |
+
std::pair<array, int> vmap(
|
| 586 |
+
const std::vector<array>& inputs,
|
| 587 |
+
const std::vector<int>& axes) override;
|
| 588 |
+
|
| 589 |
+
DEFINE_GRADS()
|
| 590 |
+
DEFINE_PRINT(Equal)
|
| 591 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 592 |
+
|
| 593 |
+
private:
|
| 594 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 595 |
+
bool equal_nan_;
|
| 596 |
+
};
|
| 597 |
+
|
| 598 |
+
class Erf : public Primitive {
|
| 599 |
+
public:
|
| 600 |
+
explicit Erf(Stream stream) : Primitive(stream){};
|
| 601 |
+
|
| 602 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 603 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 604 |
+
|
| 605 |
+
std::pair<array, int> vmap(
|
| 606 |
+
const std::vector<array>& inputs,
|
| 607 |
+
const std::vector<int>& axes) override;
|
| 608 |
+
|
| 609 |
+
DEFINE_GRADS()
|
| 610 |
+
DEFINE_PRINT(Erf)
|
| 611 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 612 |
+
|
| 613 |
+
private:
|
| 614 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 615 |
+
};
|
| 616 |
+
|
| 617 |
+
class ErfInv : public Primitive {
|
| 618 |
+
public:
|
| 619 |
+
explicit ErfInv(Stream stream) : Primitive(stream){};
|
| 620 |
+
|
| 621 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 622 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 623 |
+
|
| 624 |
+
std::pair<array, int> vmap(
|
| 625 |
+
const std::vector<array>& inputs,
|
| 626 |
+
const std::vector<int>& axes) override;
|
| 627 |
+
|
| 628 |
+
DEFINE_GRADS()
|
| 629 |
+
DEFINE_PRINT(ErfInv)
|
| 630 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 631 |
+
|
| 632 |
+
private:
|
| 633 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 634 |
+
};
|
| 635 |
+
|
| 636 |
+
class Exp : public Primitive {
|
| 637 |
+
public:
|
| 638 |
+
explicit Exp(Stream stream) : Primitive(stream){};
|
| 639 |
+
|
| 640 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 641 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 642 |
+
|
| 643 |
+
std::pair<array, int> vmap(
|
| 644 |
+
const std::vector<array>& inputs,
|
| 645 |
+
const std::vector<int>& axes) override;
|
| 646 |
+
|
| 647 |
+
DEFINE_GRADS()
|
| 648 |
+
DEFINE_PRINT(Exp)
|
| 649 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 650 |
+
|
| 651 |
+
private:
|
| 652 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 653 |
+
};
|
| 654 |
+
|
| 655 |
+
class FFT : public Primitive {
|
| 656 |
+
public:
|
| 657 |
+
explicit FFT(
|
| 658 |
+
Stream stream,
|
| 659 |
+
const std::vector<size_t>& axes,
|
| 660 |
+
bool inverse,
|
| 661 |
+
bool real)
|
| 662 |
+
: Primitive(stream), axes_(axes), inverse_(inverse), real_(real){};
|
| 663 |
+
|
| 664 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 665 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 666 |
+
|
| 667 |
+
std::pair<array, int> vmap(
|
| 668 |
+
const std::vector<array>& inputs,
|
| 669 |
+
const std::vector<int>& axes) override;
|
| 670 |
+
|
| 671 |
+
DEFINE_GRADS()
|
| 672 |
+
DEFINE_PRINT(FFT)
|
| 673 |
+
|
| 674 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 675 |
+
|
| 676 |
+
private:
|
| 677 |
+
std::vector<size_t> axes_;
|
| 678 |
+
bool inverse_;
|
| 679 |
+
bool real_;
|
| 680 |
+
|
| 681 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 682 |
+
};
|
| 683 |
+
|
| 684 |
+
class Floor : public Primitive {
|
| 685 |
+
public:
|
| 686 |
+
explicit Floor(Stream stream) : Primitive(stream){};
|
| 687 |
+
|
| 688 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 689 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 690 |
+
|
| 691 |
+
std::pair<array, int> vmap(
|
| 692 |
+
const std::vector<array>& inputs,
|
| 693 |
+
const std::vector<int>& axes) override;
|
| 694 |
+
|
| 695 |
+
DEFINE_GRADS()
|
| 696 |
+
DEFINE_PRINT(Floor)
|
| 697 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 698 |
+
|
| 699 |
+
private:
|
| 700 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 701 |
+
};
|
| 702 |
+
|
| 703 |
+
class Full : public Primitive {
|
| 704 |
+
public:
|
| 705 |
+
explicit Full(Stream stream) : Primitive(stream){};
|
| 706 |
+
|
| 707 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 708 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 709 |
+
|
| 710 |
+
std::pair<array, int> vmap(
|
| 711 |
+
const std::vector<array>& inputs,
|
| 712 |
+
const std::vector<int>& axes) override;
|
| 713 |
+
|
| 714 |
+
DEFINE_GRADS()
|
| 715 |
+
DEFINE_PRINT(Full)
|
| 716 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 717 |
+
|
| 718 |
+
private:
|
| 719 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 720 |
+
};
|
| 721 |
+
|
| 722 |
+
class Gather : public Primitive {
|
| 723 |
+
public:
|
| 724 |
+
explicit Gather(
|
| 725 |
+
Stream stream,
|
| 726 |
+
const std::vector<int>& axes,
|
| 727 |
+
const std::vector<int>& slice_sizes)
|
| 728 |
+
: Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){};
|
| 729 |
+
|
| 730 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 731 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 732 |
+
|
| 733 |
+
std::pair<array, int> vmap(
|
| 734 |
+
const std::vector<array>& inputs,
|
| 735 |
+
const std::vector<int>& axes) override;
|
| 736 |
+
|
| 737 |
+
DEFINE_GRADS()
|
| 738 |
+
DEFINE_PRINT(Gather)
|
| 739 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 740 |
+
|
| 741 |
+
private:
|
| 742 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 743 |
+
std::vector<int> axes_;
|
| 744 |
+
std::vector<int> slice_sizes_;
|
| 745 |
+
};
|
| 746 |
+
|
| 747 |
+
class Greater : public Primitive {
|
| 748 |
+
public:
|
| 749 |
+
explicit Greater(Stream stream) : Primitive(stream){};
|
| 750 |
+
|
| 751 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 752 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 753 |
+
|
| 754 |
+
std::pair<array, int> vmap(
|
| 755 |
+
const std::vector<array>& inputs,
|
| 756 |
+
const std::vector<int>& axes) override;
|
| 757 |
+
|
| 758 |
+
DEFINE_GRADS()
|
| 759 |
+
DEFINE_PRINT(Greater)
|
| 760 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 761 |
+
|
| 762 |
+
private:
|
| 763 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 764 |
+
};
|
| 765 |
+
|
| 766 |
+
class GreaterEqual : public Primitive {
|
| 767 |
+
public:
|
| 768 |
+
explicit GreaterEqual(Stream stream) : Primitive(stream){};
|
| 769 |
+
|
| 770 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 771 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 772 |
+
|
| 773 |
+
std::pair<array, int> vmap(
|
| 774 |
+
const std::vector<array>& inputs,
|
| 775 |
+
const std::vector<int>& axes) override;
|
| 776 |
+
|
| 777 |
+
DEFINE_GRADS()
|
| 778 |
+
DEFINE_PRINT(GreaterEqual)
|
| 779 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 780 |
+
|
| 781 |
+
private:
|
| 782 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 783 |
+
};
|
| 784 |
+
|
| 785 |
+
class Less : public Primitive {
|
| 786 |
+
public:
|
| 787 |
+
explicit Less(Stream stream) : Primitive(stream){};
|
| 788 |
+
|
| 789 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 790 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 791 |
+
|
| 792 |
+
std::pair<array, int> vmap(
|
| 793 |
+
const std::vector<array>& inputs,
|
| 794 |
+
const std::vector<int>& axes) override;
|
| 795 |
+
|
| 796 |
+
DEFINE_GRADS()
|
| 797 |
+
DEFINE_PRINT(Less)
|
| 798 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 799 |
+
|
| 800 |
+
private:
|
| 801 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 802 |
+
};
|
| 803 |
+
|
| 804 |
+
class LessEqual : public Primitive {
|
| 805 |
+
public:
|
| 806 |
+
explicit LessEqual(Stream stream) : Primitive(stream){};
|
| 807 |
+
|
| 808 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 809 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 810 |
+
|
| 811 |
+
std::pair<array, int> vmap(
|
| 812 |
+
const std::vector<array>& inputs,
|
| 813 |
+
const std::vector<int>& axes) override;
|
| 814 |
+
|
| 815 |
+
DEFINE_GRADS()
|
| 816 |
+
DEFINE_PRINT(LessEqual)
|
| 817 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 818 |
+
|
| 819 |
+
private:
|
| 820 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 821 |
+
};
|
| 822 |
+
|
| 823 |
+
class Load : public Primitive {
|
| 824 |
+
public:
|
| 825 |
+
explicit Load(
|
| 826 |
+
Stream stream,
|
| 827 |
+
std::shared_ptr<io::Reader> reader,
|
| 828 |
+
size_t offset,
|
| 829 |
+
bool swap_endianness = false)
|
| 830 |
+
: Primitive(stream),
|
| 831 |
+
reader_(reader),
|
| 832 |
+
offset_(offset),
|
| 833 |
+
swap_endianness_(swap_endianness){};
|
| 834 |
+
|
| 835 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 836 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 837 |
+
|
| 838 |
+
DEFINE_PRINT(Load)
|
| 839 |
+
|
| 840 |
+
private:
|
| 841 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 842 |
+
std::shared_ptr<io::Reader> reader_;
|
| 843 |
+
size_t offset_;
|
| 844 |
+
bool swap_endianness_;
|
| 845 |
+
};
|
| 846 |
+
|
| 847 |
+
class Log : public Primitive {
|
| 848 |
+
public:
|
| 849 |
+
enum Base { two, ten, e };
|
| 850 |
+
|
| 851 |
+
explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){};
|
| 852 |
+
|
| 853 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 854 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 855 |
+
|
| 856 |
+
std::pair<array, int> vmap(
|
| 857 |
+
const std::vector<array>& inputs,
|
| 858 |
+
const std::vector<int>& axes) override;
|
| 859 |
+
|
| 860 |
+
DEFINE_GRADS()
|
| 861 |
+
DEFINE_PRINT(Log)
|
| 862 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 863 |
+
|
| 864 |
+
private:
|
| 865 |
+
Base base_;
|
| 866 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 867 |
+
};
|
| 868 |
+
|
| 869 |
+
class Log1p : public Primitive {
|
| 870 |
+
public:
|
| 871 |
+
explicit Log1p(Stream stream) : Primitive(stream){};
|
| 872 |
+
|
| 873 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 874 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 875 |
+
|
| 876 |
+
std::pair<array, int> vmap(
|
| 877 |
+
const std::vector<array>& inputs,
|
| 878 |
+
const std::vector<int>& axes) override;
|
| 879 |
+
|
| 880 |
+
DEFINE_GRADS()
|
| 881 |
+
DEFINE_PRINT(Log1p)
|
| 882 |
+
|
| 883 |
+
private:
|
| 884 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 885 |
+
};
|
| 886 |
+
|
| 887 |
+
class LogicalNot : public Primitive {
|
| 888 |
+
public:
|
| 889 |
+
explicit LogicalNot(Stream stream) : Primitive(stream){};
|
| 890 |
+
|
| 891 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 892 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 893 |
+
|
| 894 |
+
std::pair<array, int> vmap(
|
| 895 |
+
const std::vector<array>& inputs,
|
| 896 |
+
const std::vector<int>& axes) override;
|
| 897 |
+
|
| 898 |
+
DEFINE_GRADS()
|
| 899 |
+
DEFINE_PRINT(LogicalNot)
|
| 900 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 901 |
+
|
| 902 |
+
private:
|
| 903 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 904 |
+
};
|
| 905 |
+
|
| 906 |
+
class LogAddExp : public Primitive {
|
| 907 |
+
public:
|
| 908 |
+
explicit LogAddExp(Stream stream) : Primitive(stream){};
|
| 909 |
+
|
| 910 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 911 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 912 |
+
|
| 913 |
+
std::pair<array, int> vmap(
|
| 914 |
+
const std::vector<array>& inputs,
|
| 915 |
+
const std::vector<int>& axes) override;
|
| 916 |
+
|
| 917 |
+
DEFINE_GRADS()
|
| 918 |
+
DEFINE_PRINT(LogAddExp)
|
| 919 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 920 |
+
|
| 921 |
+
private:
|
| 922 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 923 |
+
};
|
| 924 |
+
|
| 925 |
+
class Matmul : public Primitive {
|
| 926 |
+
public:
|
| 927 |
+
explicit Matmul(Stream stream) : Primitive(stream){};
|
| 928 |
+
|
| 929 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 930 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 931 |
+
|
| 932 |
+
std::vector<array> vjp(
|
| 933 |
+
const std::vector<array>& primals,
|
| 934 |
+
const array& cotan,
|
| 935 |
+
const std::vector<int>& argnums) override;
|
| 936 |
+
|
| 937 |
+
std::pair<array, int> vmap(
|
| 938 |
+
const std::vector<array>& inputs,
|
| 939 |
+
const std::vector<int>& axes) override;
|
| 940 |
+
|
| 941 |
+
DEFINE_PRINT(Matmul)
|
| 942 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 943 |
+
};
|
| 944 |
+
|
| 945 |
+
class Maximum : public Primitive {
|
| 946 |
+
public:
|
| 947 |
+
explicit Maximum(Stream stream) : Primitive(stream){};
|
| 948 |
+
|
| 949 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 950 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 951 |
+
|
| 952 |
+
std::pair<array, int> vmap(
|
| 953 |
+
const std::vector<array>& inputs,
|
| 954 |
+
const std::vector<int>& axes) override;
|
| 955 |
+
|
| 956 |
+
DEFINE_GRADS()
|
| 957 |
+
DEFINE_PRINT(Maximum)
|
| 958 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 959 |
+
|
| 960 |
+
private:
|
| 961 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 962 |
+
};
|
| 963 |
+
|
| 964 |
+
class Minimum : public Primitive {
|
| 965 |
+
public:
|
| 966 |
+
explicit Minimum(Stream stream) : Primitive(stream){};
|
| 967 |
+
|
| 968 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 969 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 970 |
+
|
| 971 |
+
std::pair<array, int> vmap(
|
| 972 |
+
const std::vector<array>& inputs,
|
| 973 |
+
const std::vector<int>& axes) override;
|
| 974 |
+
|
| 975 |
+
DEFINE_GRADS()
|
| 976 |
+
DEFINE_PRINT(Minimum)
|
| 977 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 978 |
+
|
| 979 |
+
private:
|
| 980 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 981 |
+
};
|
| 982 |
+
|
| 983 |
+
class Multiply : public Primitive {
|
| 984 |
+
public:
|
| 985 |
+
explicit Multiply(Stream stream) : Primitive(stream){};
|
| 986 |
+
|
| 987 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 988 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 989 |
+
|
| 990 |
+
std::pair<array, int> vmap(
|
| 991 |
+
const std::vector<array>& inputs,
|
| 992 |
+
const std::vector<int>& axes) override;
|
| 993 |
+
|
| 994 |
+
DEFINE_GRADS()
|
| 995 |
+
DEFINE_PRINT(Multiply)
|
| 996 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 997 |
+
|
| 998 |
+
private:
|
| 999 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1000 |
+
};
|
| 1001 |
+
|
| 1002 |
+
class Negative : public Primitive {
|
| 1003 |
+
public:
|
| 1004 |
+
explicit Negative(Stream stream) : Primitive(stream){};
|
| 1005 |
+
|
| 1006 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1007 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1008 |
+
|
| 1009 |
+
std::pair<array, int> vmap(
|
| 1010 |
+
const std::vector<array>& inputs,
|
| 1011 |
+
const std::vector<int>& axes) override;
|
| 1012 |
+
|
| 1013 |
+
DEFINE_GRADS()
|
| 1014 |
+
DEFINE_PRINT(Negative)
|
| 1015 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1016 |
+
|
| 1017 |
+
private:
|
| 1018 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1019 |
+
};
|
| 1020 |
+
|
| 1021 |
+
class NotEqual : public Primitive {
|
| 1022 |
+
public:
|
| 1023 |
+
explicit NotEqual(Stream stream) : Primitive(stream){};
|
| 1024 |
+
|
| 1025 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1026 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1027 |
+
|
| 1028 |
+
std::pair<array, int> vmap(
|
| 1029 |
+
const std::vector<array>& inputs,
|
| 1030 |
+
const std::vector<int>& axes) override;
|
| 1031 |
+
|
| 1032 |
+
DEFINE_GRADS()
|
| 1033 |
+
DEFINE_PRINT(NotEqual)
|
| 1034 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1035 |
+
|
| 1036 |
+
private:
|
| 1037 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1038 |
+
};
|
| 1039 |
+
|
| 1040 |
+
class Pad : public Primitive {
|
| 1041 |
+
public:
|
| 1042 |
+
explicit Pad(
|
| 1043 |
+
Stream stream,
|
| 1044 |
+
const std::vector<int>& axes,
|
| 1045 |
+
const std::vector<int>& low_pad_size,
|
| 1046 |
+
const std::vector<int>& high_pad_size)
|
| 1047 |
+
: Primitive(stream),
|
| 1048 |
+
axes_(axes),
|
| 1049 |
+
low_pad_size_(low_pad_size),
|
| 1050 |
+
high_pad_size_(high_pad_size){};
|
| 1051 |
+
|
| 1052 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1053 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1054 |
+
|
| 1055 |
+
std::pair<array, int> vmap(
|
| 1056 |
+
const std::vector<array>& inputs,
|
| 1057 |
+
const std::vector<int>& axes) override;
|
| 1058 |
+
|
| 1059 |
+
DEFINE_GRADS()
|
| 1060 |
+
DEFINE_PRINT(Pad)
|
| 1061 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1062 |
+
|
| 1063 |
+
private:
|
| 1064 |
+
std::vector<int> axes_;
|
| 1065 |
+
std::vector<int> low_pad_size_;
|
| 1066 |
+
std::vector<int> high_pad_size_;
|
| 1067 |
+
|
| 1068 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1069 |
+
};
|
| 1070 |
+
|
| 1071 |
+
class Partition : public Primitive {
|
| 1072 |
+
public:
|
| 1073 |
+
explicit Partition(Stream stream, int kth, int axis)
|
| 1074 |
+
: Primitive(stream), kth_(kth), axis_(axis){};
|
| 1075 |
+
|
| 1076 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1077 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1078 |
+
|
| 1079 |
+
std::pair<array, int> vmap(
|
| 1080 |
+
const std::vector<array>& inputs,
|
| 1081 |
+
const std::vector<int>& axes) override;
|
| 1082 |
+
|
| 1083 |
+
DEFINE_GRADS()
|
| 1084 |
+
DEFINE_PRINT(Partition)
|
| 1085 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1086 |
+
|
| 1087 |
+
private:
|
| 1088 |
+
int kth_;
|
| 1089 |
+
int axis_;
|
| 1090 |
+
|
| 1091 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1092 |
+
};
|
| 1093 |
+
|
| 1094 |
+
class Power : public Primitive {
|
| 1095 |
+
public:
|
| 1096 |
+
explicit Power(Stream stream) : Primitive(stream){};
|
| 1097 |
+
|
| 1098 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1099 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1100 |
+
|
| 1101 |
+
std::pair<array, int> vmap(
|
| 1102 |
+
const std::vector<array>& inputs,
|
| 1103 |
+
const std::vector<int>& axes) override;
|
| 1104 |
+
|
| 1105 |
+
DEFINE_GRADS()
|
| 1106 |
+
DEFINE_PRINT(Power)
|
| 1107 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1108 |
+
|
| 1109 |
+
private:
|
| 1110 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1111 |
+
};
|
| 1112 |
+
|
| 1113 |
+
class QuantizedMatmul : public Primitive {
|
| 1114 |
+
public:
|
| 1115 |
+
explicit QuantizedMatmul(
|
| 1116 |
+
Stream stream,
|
| 1117 |
+
int group_size,
|
| 1118 |
+
int bits,
|
| 1119 |
+
bool transpose)
|
| 1120 |
+
: Primitive(stream),
|
| 1121 |
+
group_size_(group_size),
|
| 1122 |
+
bits_(bits),
|
| 1123 |
+
transpose_(transpose){};
|
| 1124 |
+
|
| 1125 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1126 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1127 |
+
|
| 1128 |
+
std::pair<array, int> vmap(
|
| 1129 |
+
const std::vector<array>& inputs,
|
| 1130 |
+
const std::vector<int>& axes) override;
|
| 1131 |
+
|
| 1132 |
+
DEFINE_GRADS()
|
| 1133 |
+
DEFINE_PRINT(QuantizedMatmul)
|
| 1134 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1135 |
+
|
| 1136 |
+
private:
|
| 1137 |
+
int group_size_;
|
| 1138 |
+
int bits_;
|
| 1139 |
+
bool transpose_;
|
| 1140 |
+
|
| 1141 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1142 |
+
};
|
| 1143 |
+
|
| 1144 |
+
class RandomBits : public Primitive {
|
| 1145 |
+
public:
|
| 1146 |
+
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
| 1147 |
+
: Primitive(stream), shape_(shape), width_(width){};
|
| 1148 |
+
|
| 1149 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1150 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1151 |
+
|
| 1152 |
+
std::pair<array, int> vmap(
|
| 1153 |
+
const std::vector<array>& inputs,
|
| 1154 |
+
const std::vector<int>& axes) override;
|
| 1155 |
+
|
| 1156 |
+
DEFINE_PRINT(RandomBits)
|
| 1157 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1158 |
+
|
| 1159 |
+
private:
|
| 1160 |
+
std::vector<int> shape_;
|
| 1161 |
+
int width_;
|
| 1162 |
+
|
| 1163 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1164 |
+
};
|
| 1165 |
+
|
| 1166 |
+
class Reshape : public Primitive {
|
| 1167 |
+
public:
|
| 1168 |
+
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
| 1169 |
+
: Primitive(stream), shape_(shape){};
|
| 1170 |
+
|
| 1171 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1172 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1173 |
+
|
| 1174 |
+
std::pair<array, int> vmap(
|
| 1175 |
+
const std::vector<array>& inputs,
|
| 1176 |
+
const std::vector<int>& axes) override;
|
| 1177 |
+
|
| 1178 |
+
DEFINE_GRADS()
|
| 1179 |
+
DEFINE_PRINT(Reshape)
|
| 1180 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1181 |
+
|
| 1182 |
+
private:
|
| 1183 |
+
std::vector<int> shape_;
|
| 1184 |
+
|
| 1185 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1186 |
+
};
|
| 1187 |
+
|
| 1188 |
+
class Reduce : public Primitive {
|
| 1189 |
+
public:
|
| 1190 |
+
enum ReduceType { And, Or, Sum, Prod, Min, Max };
|
| 1191 |
+
|
| 1192 |
+
explicit Reduce(
|
| 1193 |
+
Stream stream,
|
| 1194 |
+
ReduceType reduce_type,
|
| 1195 |
+
const std::vector<int>& axes)
|
| 1196 |
+
: Primitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
| 1197 |
+
|
| 1198 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1199 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1200 |
+
|
| 1201 |
+
std::pair<array, int> vmap(
|
| 1202 |
+
const std::vector<array>& inputs,
|
| 1203 |
+
const std::vector<int>& axes) override;
|
| 1204 |
+
std::vector<array> vjp(
|
| 1205 |
+
const std::vector<array>& primals,
|
| 1206 |
+
const array& cotan,
|
| 1207 |
+
const std::vector<int>& argnums) override;
|
| 1208 |
+
|
| 1209 |
+
  void print(std::ostream& os) override {
    switch (reduce_type_) {
      case And:
        os << "And";
        break;
      case Or:
        os << "Or";
        break;
      case Sum:
        os << "Sum";
        break;
      case Prod:
        os << "Prod";
        break;
      case Min:
        os << "Min";
        break;
      case Max:
        os << "Max";
        break;
    }
    os << " Reduce";
  }
|
| 1231 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1232 |
+
|
| 1233 |
+
private:
|
| 1234 |
+
ReduceType reduce_type_;
|
| 1235 |
+
std::vector<int> axes_;
|
| 1236 |
+
|
| 1237 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1238 |
+
};
|
| 1239 |
+
|
| 1240 |
+
class Round : public Primitive {
|
| 1241 |
+
public:
|
| 1242 |
+
explicit Round(Stream stream) : Primitive(stream){};
|
| 1243 |
+
|
| 1244 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1245 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1246 |
+
|
| 1247 |
+
std::pair<array, int> vmap(
|
| 1248 |
+
const std::vector<array>& inputs,
|
| 1249 |
+
const std::vector<int>& axes) override;
|
| 1250 |
+
|
| 1251 |
+
DEFINE_GRADS()
|
| 1252 |
+
DEFINE_PRINT(Round)
|
| 1253 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1254 |
+
|
| 1255 |
+
private:
|
| 1256 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1257 |
+
};
|
| 1258 |
+
|
| 1259 |
+
class Scan : public Primitive {
|
| 1260 |
+
public:
|
| 1261 |
+
enum ReduceType { Max, Min, Sum, Prod };
|
| 1262 |
+
|
| 1263 |
+
explicit Scan(
|
| 1264 |
+
Stream stream,
|
| 1265 |
+
ReduceType reduce_type,
|
| 1266 |
+
int axis,
|
| 1267 |
+
bool reverse,
|
| 1268 |
+
bool inclusive)
|
| 1269 |
+
: Primitive(stream),
|
| 1270 |
+
reduce_type_(reduce_type),
|
| 1271 |
+
axis_(axis),
|
| 1272 |
+
reverse_(reverse),
|
| 1273 |
+
inclusive_(inclusive){};
|
| 1274 |
+
|
| 1275 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1276 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1277 |
+
|
| 1278 |
+
std::pair<array, int> vmap(
|
| 1279 |
+
const std::vector<array>& inputs,
|
| 1280 |
+
const std::vector<int>& axes) override;
|
| 1281 |
+
|
| 1282 |
+
DEFINE_GRADS();
|
| 1283 |
+
void print(std::ostream& os) override {
|
| 1284 |
+
os << "Cum";
|
| 1285 |
+
switch (reduce_type_) {
|
| 1286 |
+
case Sum:
|
| 1287 |
+
os << "Sum";
|
| 1288 |
+
break;
|
| 1289 |
+
case Prod:
|
| 1290 |
+
os << "Prod";
|
| 1291 |
+
break;
|
| 1292 |
+
case Min:
|
| 1293 |
+
os << "Min";
|
| 1294 |
+
break;
|
| 1295 |
+
case Max:
|
| 1296 |
+
os << "Max";
|
| 1297 |
+
break;
|
| 1298 |
+
}
|
| 1299 |
+
os << " Reduce";
|
| 1300 |
+
}
|
| 1301 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1302 |
+
|
| 1303 |
+
private:
|
| 1304 |
+
ReduceType reduce_type_;
|
| 1305 |
+
int axis_;
|
| 1306 |
+
bool reverse_;
|
| 1307 |
+
bool inclusive_;
|
| 1308 |
+
|
| 1309 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1310 |
+
};
|
| 1311 |
+
|
| 1312 |
+
class Scatter : public Primitive {
|
| 1313 |
+
public:
|
| 1314 |
+
enum ReduceType { Max, Min, Sum, Prod, None };
|
| 1315 |
+
|
| 1316 |
+
explicit Scatter(
|
| 1317 |
+
Stream stream,
|
| 1318 |
+
ReduceType reduce_type,
|
| 1319 |
+
const std::vector<int>& axes)
|
| 1320 |
+
: Primitive(stream), reduce_type_(reduce_type), axes_(axes){};
|
| 1321 |
+
|
| 1322 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1323 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1324 |
+
|
| 1325 |
+
DEFINE_PRINT(Scatter)
|
| 1326 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1327 |
+
|
| 1328 |
+
private:
|
| 1329 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1330 |
+
ReduceType reduce_type_;
|
| 1331 |
+
std::vector<int> axes_;
|
| 1332 |
+
};
|
| 1333 |
+
|
| 1334 |
+
class Sigmoid : public Primitive {
|
| 1335 |
+
public:
|
| 1336 |
+
explicit Sigmoid(Stream stream) : Primitive(stream){};
|
| 1337 |
+
|
| 1338 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1339 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1340 |
+
|
| 1341 |
+
std::pair<array, int> vmap(
|
| 1342 |
+
const std::vector<array>& inputs,
|
| 1343 |
+
const std::vector<int>& axes) override;
|
| 1344 |
+
|
| 1345 |
+
DEFINE_GRADS()
|
| 1346 |
+
DEFINE_PRINT(Sigmoid)
|
| 1347 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1348 |
+
|
| 1349 |
+
private:
|
| 1350 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1351 |
+
};
|
| 1352 |
+
|
| 1353 |
+
class Sign : public Primitive {
|
| 1354 |
+
public:
|
| 1355 |
+
explicit Sign(Stream stream) : Primitive(stream){};
|
| 1356 |
+
|
| 1357 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1358 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1359 |
+
|
| 1360 |
+
std::pair<array, int> vmap(
|
| 1361 |
+
const std::vector<array>& inputs,
|
| 1362 |
+
const std::vector<int>& axes) override;
|
| 1363 |
+
|
| 1364 |
+
DEFINE_GRADS()
|
| 1365 |
+
DEFINE_PRINT(Sign)
|
| 1366 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1367 |
+
|
| 1368 |
+
private:
|
| 1369 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1370 |
+
};
|
| 1371 |
+
|
| 1372 |
+
class Sin : public Primitive {
|
| 1373 |
+
public:
|
| 1374 |
+
explicit Sin(Stream stream) : Primitive(stream){};
|
| 1375 |
+
|
| 1376 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1377 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1378 |
+
|
| 1379 |
+
std::pair<array, int> vmap(
|
| 1380 |
+
const std::vector<array>& inputs,
|
| 1381 |
+
const std::vector<int>& axes) override;
|
| 1382 |
+
|
| 1383 |
+
DEFINE_GRADS()
|
| 1384 |
+
DEFINE_PRINT(Sin)
|
| 1385 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1386 |
+
|
| 1387 |
+
private:
|
| 1388 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1389 |
+
};
|
| 1390 |
+
|
| 1391 |
+
class Sinh : public Primitive {
|
| 1392 |
+
public:
|
| 1393 |
+
explicit Sinh(Stream stream) : Primitive(stream){};
|
| 1394 |
+
|
| 1395 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1396 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1397 |
+
|
| 1398 |
+
std::pair<array, int> vmap(
|
| 1399 |
+
const std::vector<array>& inputs,
|
| 1400 |
+
const std::vector<int>& axes) override;
|
| 1401 |
+
|
| 1402 |
+
DEFINE_GRADS()
|
| 1403 |
+
DEFINE_PRINT(Sinh)
|
| 1404 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1405 |
+
|
| 1406 |
+
private:
|
| 1407 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1408 |
+
};
|
| 1409 |
+
|
| 1410 |
+
class Slice : public Primitive {
|
| 1411 |
+
public:
|
| 1412 |
+
explicit Slice(
|
| 1413 |
+
Stream stream,
|
| 1414 |
+
const std::vector<int>& start_indices,
|
| 1415 |
+
const std::vector<int>& end_indices,
|
| 1416 |
+
const std::vector<int>& strides)
|
| 1417 |
+
: Primitive(stream),
|
| 1418 |
+
start_indices_(start_indices),
|
| 1419 |
+
end_indices_(end_indices),
|
| 1420 |
+
strides_(strides){};
|
| 1421 |
+
|
| 1422 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1423 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1424 |
+
|
| 1425 |
+
std::pair<array, int> vmap(
|
| 1426 |
+
const std::vector<array>& inputs,
|
| 1427 |
+
const std::vector<int>& axes) override;
|
| 1428 |
+
|
| 1429 |
+
DEFINE_GRADS()
|
| 1430 |
+
DEFINE_PRINT(Slice)
|
| 1431 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1432 |
+
|
| 1433 |
+
private:
|
| 1434 |
+
std::vector<int> start_indices_;
|
| 1435 |
+
std::vector<int> end_indices_;
|
| 1436 |
+
std::vector<int> strides_;
|
| 1437 |
+
|
| 1438 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1439 |
+
};
|
| 1440 |
+
|
| 1441 |
+
class Softmax : public Primitive {
|
| 1442 |
+
public:
|
| 1443 |
+
explicit Softmax(Stream stream) : Primitive(stream){};
|
| 1444 |
+
|
| 1445 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1446 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1447 |
+
|
| 1448 |
+
std::pair<array, int> vmap(
|
| 1449 |
+
const std::vector<array>& inputs,
|
| 1450 |
+
const std::vector<int>& axes) override;
|
| 1451 |
+
|
| 1452 |
+
DEFINE_GRADS()
|
| 1453 |
+
DEFINE_PRINT(Softmax)
|
| 1454 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1455 |
+
|
| 1456 |
+
private:
|
| 1457 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1458 |
+
};
|
| 1459 |
+
|
| 1460 |
+
class Sort : public Primitive {
|
| 1461 |
+
public:
|
| 1462 |
+
explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){};
|
| 1463 |
+
|
| 1464 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1465 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1466 |
+
|
| 1467 |
+
std::pair<array, int> vmap(
|
| 1468 |
+
const std::vector<array>& inputs,
|
| 1469 |
+
const std::vector<int>& axes) override;
|
| 1470 |
+
|
| 1471 |
+
DEFINE_GRADS()
|
| 1472 |
+
DEFINE_PRINT(Sort)
|
| 1473 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1474 |
+
|
| 1475 |
+
private:
|
| 1476 |
+
int axis_;
|
| 1477 |
+
|
| 1478 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1479 |
+
};
|
| 1480 |
+
|
| 1481 |
+
class Square : public Primitive {
|
| 1482 |
+
public:
|
| 1483 |
+
explicit Square(Stream stream) : Primitive(stream){};
|
| 1484 |
+
|
| 1485 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1486 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1487 |
+
|
| 1488 |
+
std::pair<array, int> vmap(
|
| 1489 |
+
const std::vector<array>& inputs,
|
| 1490 |
+
const std::vector<int>& axes) override;
|
| 1491 |
+
|
| 1492 |
+
DEFINE_GRADS()
|
| 1493 |
+
DEFINE_PRINT(Square)
|
| 1494 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1495 |
+
|
| 1496 |
+
private:
|
| 1497 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1498 |
+
};
|
| 1499 |
+
|
| 1500 |
+
class Sqrt : public Primitive {
|
| 1501 |
+
public:
|
| 1502 |
+
explicit Sqrt(Stream stream, bool recip = false)
|
| 1503 |
+
: Primitive(stream), recip_(recip){};
|
| 1504 |
+
|
| 1505 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1506 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1507 |
+
|
| 1508 |
+
std::pair<array, int> vmap(
|
| 1509 |
+
const std::vector<array>& inputs,
|
| 1510 |
+
const std::vector<int>& axes) override;
|
| 1511 |
+
|
| 1512 |
+
DEFINE_GRADS()
|
| 1513 |
+
DEFINE_PRINT(Sqrt)
|
| 1514 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1515 |
+
|
| 1516 |
+
private:
|
| 1517 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1518 |
+
bool recip_;
|
| 1519 |
+
};
|
| 1520 |
+
|
| 1521 |
+
class StopGradient : public Primitive {
|
| 1522 |
+
public:
|
| 1523 |
+
explicit StopGradient(Stream stream) : Primitive(stream){};
|
| 1524 |
+
|
| 1525 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1526 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1527 |
+
|
| 1528 |
+
std::pair<array, int> vmap(
|
| 1529 |
+
const std::vector<array>& inputs,
|
| 1530 |
+
const std::vector<int>& axes) override;
|
| 1531 |
+
|
| 1532 |
+
DEFINE_PRINT(StopGradient)
|
| 1533 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1534 |
+
|
| 1535 |
+
private:
|
| 1536 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1537 |
+
};
|
| 1538 |
+
|
| 1539 |
+
class Subtract : public Primitive {
|
| 1540 |
+
public:
|
| 1541 |
+
explicit Subtract(Stream stream) : Primitive(stream){};
|
| 1542 |
+
|
| 1543 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1544 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1545 |
+
|
| 1546 |
+
std::pair<array, int> vmap(
|
| 1547 |
+
const std::vector<array>& inputs,
|
| 1548 |
+
const std::vector<int>& axes) override;
|
| 1549 |
+
|
| 1550 |
+
DEFINE_GRADS()
|
| 1551 |
+
DEFINE_PRINT(Subtract)
|
| 1552 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1553 |
+
|
| 1554 |
+
private:
|
| 1555 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1556 |
+
};
|
| 1557 |
+
|
| 1558 |
+
class Tan : public Primitive {
|
| 1559 |
+
public:
|
| 1560 |
+
explicit Tan(Stream stream) : Primitive(stream){};
|
| 1561 |
+
|
| 1562 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1563 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1564 |
+
|
| 1565 |
+
std::pair<array, int> vmap(
|
| 1566 |
+
const std::vector<array>& inputs,
|
| 1567 |
+
const std::vector<int>& axes) override;
|
| 1568 |
+
|
| 1569 |
+
DEFINE_GRADS()
|
| 1570 |
+
DEFINE_PRINT(Tan)
|
| 1571 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1572 |
+
|
| 1573 |
+
private:
|
| 1574 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1575 |
+
};
|
| 1576 |
+
|
| 1577 |
+
class Tanh : public Primitive {
|
| 1578 |
+
public:
|
| 1579 |
+
explicit Tanh(Stream stream) : Primitive(stream){};
|
| 1580 |
+
|
| 1581 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1582 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1583 |
+
|
| 1584 |
+
std::pair<array, int> vmap(
|
| 1585 |
+
const std::vector<array>& inputs,
|
| 1586 |
+
const std::vector<int>& axes) override;
|
| 1587 |
+
|
| 1588 |
+
DEFINE_GRADS()
|
| 1589 |
+
DEFINE_PRINT(Tanh)
|
| 1590 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1591 |
+
|
| 1592 |
+
private:
|
| 1593 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1594 |
+
};
|
| 1595 |
+
|
| 1596 |
+
class Uniform : public Primitive {
|
| 1597 |
+
public:
|
| 1598 |
+
explicit Uniform(Stream stream) : Primitive(stream){};
|
| 1599 |
+
|
| 1600 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1601 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1602 |
+
|
| 1603 |
+
std::pair<array, int> vmap(
|
| 1604 |
+
const std::vector<array>& inputs,
|
| 1605 |
+
const std::vector<int>& axes) override;
|
| 1606 |
+
|
| 1607 |
+
DEFINE_PRINT(Uniform)
|
| 1608 |
+
DEFINE_DEFAULT_IS_EQUIVALENT()
|
| 1609 |
+
|
| 1610 |
+
private:
|
| 1611 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1612 |
+
};
|
| 1613 |
+
|
| 1614 |
+
class Transpose : public Primitive {
|
| 1615 |
+
public:
|
| 1616 |
+
explicit Transpose(Stream stream, const std::vector<int>& axes)
|
| 1617 |
+
: Primitive(stream), axes_(axes){};
|
| 1618 |
+
|
| 1619 |
+
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
| 1620 |
+
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
| 1621 |
+
|
| 1622 |
+
std::pair<array, int> vmap(
|
| 1623 |
+
const std::vector<array>& inputs,
|
| 1624 |
+
const std::vector<int>& axes) override;
|
| 1625 |
+
|
| 1626 |
+
DEFINE_GRADS()
|
| 1627 |
+
DEFINE_PRINT(Transpose)
|
| 1628 |
+
bool is_equivalent(const Primitive& other) const override;
|
| 1629 |
+
|
| 1630 |
+
private:
|
| 1631 |
+
std::vector<int> axes_;
|
| 1632 |
+
|
| 1633 |
+
void eval(const std::vector<array>& inputs, array& out);
|
| 1634 |
+
};
|
| 1635 |
+
|
| 1636 |
+
} // namespace mlx::core
|
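The `Scan` primitive above carries an axis, a `reverse_` flag, and an `inclusive_` flag that together describe cumulative reductions such as cumsum. As a rough illustration of what inclusive versus exclusive means for a scan, here is a plain C++ sketch over a std::vector; it is illustrative only and not MLX code:

// Illustrative only: inclusive vs. exclusive prefix sums over one axis.
#include <cstdio>
#include <vector>

std::vector<int> cumsum(const std::vector<int>& x, bool inclusive) {
  std::vector<int> out(x.size());
  int acc = 0;
  for (size_t i = 0; i < x.size(); ++i) {
    if (inclusive) {
      acc += x[i];
      out[i] = acc; // prefix includes x[i]
    } else {
      out[i] = acc; // prefix excludes x[i]
      acc += x[i];
    }
  }
  return out;
}

int main() {
  std::vector<int> x{1, 2, 3, 4};
  auto inc = cumsum(x, true);  // 1 3 6 10
  auto exc = cumsum(x, false); // 0 1 3 6
  for (int v : inc) std::printf("%d ", v);
  std::printf("\n");
  for (int v : exc) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}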
lib/python3.11/site-packages/mlx/include/mlx/random.h
ADDED
@@ -0,0 +1,193 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::random {

class KeySequence {
 public:
  explicit KeySequence(uint64_t seed);

  void seed(uint64_t seed);
  array next();

  // static default
  static KeySequence& default_() {
    static KeySequence ks(0);
    return ks;
  }

 private:
  array key_;
};

/** Get a PRNG key from a seed. */
array key(uint64_t seed);

/** Seed the default PRNG key. */
void seed(uint64_t seed);

/** Generate an array with type uint32 filled with random bits. */
array bits(
    const std::vector<int>& shape,
    int width,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array bits(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bits(shape, 4, key, s);
}

/** Split the rng key into a pair of keys. */
std::pair<array, array> split(const array& key, StreamOrDevice s = {});

/** Split the rng key into `num` keys. */
array split(const array& key, int num, StreamOrDevice s = {});

/** Generate uniform random numbers between low and high. */
array uniform(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T, typename U>
array uniform(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
}

/** Generate uniform random numbers between 0 and 1. */
array uniform(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array uniform(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return uniform(shape, float32, key);
}

/** Generate samples from the standard normal distribution. */
array normal(
    const std::vector<int>& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
inline array normal(
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return normal(shape, float32, key, s);
}

/** Generate integer samples uniformly at random */
array randint(
    const array& low,
    const array& high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T, typename U>
array randint(
    T low,
    U high,
    const std::vector<int>& shape,
    Dtype dtype = int32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return randint(array(low), array(high), shape, dtype, key, to_stream(s));
};

/** Generate binary variables with probability to be true equal to p */
array bernoulli(
    const array& p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
array bernoulli(
    const array& p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

template <typename T>
array bernoulli(
    T p,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), key, s);
};

template <typename T>
array bernoulli(
    T p,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {}) {
  return bernoulli(array(p), shape, key, s);
};

array bernoulli(
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array truncated_normal(
    const array& lower,
    const array& upper,
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array truncated_normal(
    const array& lower,
    const array& upper,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array gumbel(
    const std::vector<int>& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array categorical(
    const array& logits,
    int axis,
    const std::vector<int>& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array categorical(
    const array& logits_,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::random
|
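Taken together, the declarations above sketch a splittable-key PRNG API. The following is a minimal usage sketch, assuming the package's umbrella header "mlx/mlx.h" pulls in these declarations and `eval`; shapes and values are illustrative only:

// Example (illustrative, not part of the MLX headers).
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Seed the default key sequence, or derive an explicit key.
  random::seed(42);
  array k = random::key(42);

  // Split the key so independent draws do not reuse randomness.
  auto [k1, k2] = random::split(k);

  array u = random::uniform({2, 3});                   // U[0, 1), float32
  array n = random::normal({2, 3}, float32, k1);       // standard normal
  array r = random::randint(array(0), array(10), {4}); // integers in [0, 10)
  array b = random::bernoulli(array(0.3f), {5}, k2);   // P(true) = 0.3

  eval(u, n, r, b); // force evaluation of the lazy graph
  return 0;
}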
lib/python3.11/site-packages/mlx/include/mlx/scheduler.h
ADDED
@@ -0,0 +1,173 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <atomic>
|
| 6 |
+
#include <future>
|
| 7 |
+
#include <queue>
|
| 8 |
+
#include <thread>
|
| 9 |
+
#include <unordered_map>
|
| 10 |
+
|
| 11 |
+
#include "mlx/backend/metal/metal.h"
|
| 12 |
+
#include "mlx/device.h"
|
| 13 |
+
#include "mlx/stream.h"
|
| 14 |
+
|
| 15 |
+
namespace mlx::core::scheduler {
|
| 16 |
+
|
| 17 |
+
struct StreamThread {
|
| 18 |
+
std::mutex mtx;
|
| 19 |
+
std::queue<std::function<void()>> q;
|
| 20 |
+
std::condition_variable cond;
|
| 21 |
+
bool stop;
|
| 22 |
+
Stream stream;
|
| 23 |
+
std::thread thread;
|
| 24 |
+
|
| 25 |
+
StreamThread(Stream stream)
|
| 26 |
+
: stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {}
|
| 27 |
+
|
| 28 |
+
~StreamThread() {
|
| 29 |
+
{
|
| 30 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 31 |
+
stop = true;
|
| 32 |
+
}
|
| 33 |
+
cond.notify_one();
|
| 34 |
+
thread.join();
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
void thread_fn() {
|
| 38 |
+
auto thread_pool = metal::new_scoped_memory_pool();
|
| 39 |
+
metal::new_stream(stream);
|
| 40 |
+
while (true) {
|
| 41 |
+
std::function<void()> task;
|
| 42 |
+
{
|
| 43 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 44 |
+
cond.wait(lk, [this] { return !this->q.empty() || this->stop; });
|
| 45 |
+
if (q.empty() && stop) {
|
| 46 |
+
return;
|
| 47 |
+
}
|
| 48 |
+
task = std::move(q.front());
|
| 49 |
+
q.pop();
|
| 50 |
+
}
|
| 51 |
+
task();
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
template <typename F>
|
| 56 |
+
void enqueue(F&& f) {
|
| 57 |
+
{
|
| 58 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 59 |
+
if (stop) {
|
| 60 |
+
throw std::runtime_error(
|
| 61 |
+
"Cannot enqueue work after stream is stopped.");
|
| 62 |
+
}
|
| 63 |
+
q.emplace(std::forward<F>(f));
|
| 64 |
+
}
|
| 65 |
+
cond.notify_one();
|
| 66 |
+
}
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
class Scheduler {
|
| 70 |
+
public:
|
| 71 |
+
Scheduler() : n_active_tasks_(0) {
|
| 72 |
+
if (metal::is_available()) {
|
| 73 |
+
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
| 74 |
+
}
|
| 75 |
+
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
// Not copyable or moveable
|
| 79 |
+
Scheduler(const Scheduler&) = delete;
|
| 80 |
+
Scheduler(Scheduler&&) = delete;
|
| 81 |
+
Scheduler& operator=(const Scheduler&) = delete;
|
| 82 |
+
Scheduler& operator=(Scheduler&&) = delete;
|
| 83 |
+
|
| 84 |
+
Stream new_stream(const Device& d) {
|
| 85 |
+
auto stream = Stream(streams_.size(), d);
|
| 86 |
+
streams_.push_back(new StreamThread{stream});
|
| 87 |
+
return stream;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template <typename F>
|
| 91 |
+
void enqueue(const Stream& stream, F&& f);
|
| 92 |
+
|
| 93 |
+
Stream get_default_stream(const Device& d) {
|
| 94 |
+
return default_streams_.at(d.type);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
void set_default_stream(const Stream& s) {
|
| 98 |
+
default_streams_.at(s.device.type) = s;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
void notify_new_task(const Stream& stream) {
|
| 102 |
+
{
|
| 103 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 104 |
+
n_active_tasks_++;
|
| 105 |
+
}
|
| 106 |
+
completion_cv.notify_all();
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
void notify_task_completion(const Stream& stream) {
|
| 110 |
+
{
|
| 111 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 112 |
+
n_active_tasks_--;
|
| 113 |
+
}
|
| 114 |
+
completion_cv.notify_all();
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
int n_active_tasks() const {
|
| 118 |
+
return n_active_tasks_;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
void wait_for_one() {
|
| 122 |
+
std::unique_lock<std::mutex> lk(mtx);
|
| 123 |
+
int n_tasks_old = n_active_tasks();
|
| 124 |
+
if (n_tasks_old > 1) {
|
| 125 |
+
completion_cv.wait(lk, [this, n_tasks_old] {
|
| 126 |
+
return this->n_active_tasks() != n_tasks_old;
|
| 127 |
+
});
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
~Scheduler() {
|
| 132 |
+
for (auto s : streams_) {
|
| 133 |
+
delete s;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
private:
|
| 138 |
+
int n_active_tasks_;
|
| 139 |
+
std::vector<StreamThread*> streams_;
|
| 140 |
+
std::unordered_map<Device::DeviceType, Stream> default_streams_;
|
| 141 |
+
std::condition_variable completion_cv;
|
| 142 |
+
std::mutex mtx;
|
| 143 |
+
};
|
| 144 |
+
|
| 145 |
+
template <typename F>
|
| 146 |
+
void Scheduler::enqueue(const Stream& stream, F&& f) {
|
| 147 |
+
streams_[stream.index]->enqueue(std::forward<F>(f));
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
Scheduler& scheduler();
|
| 151 |
+
|
| 152 |
+
template <typename F>
|
| 153 |
+
void enqueue(const Stream& stream, F&& f) {
|
| 154 |
+
scheduler().enqueue(stream, std::forward<F>(f));
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
inline int n_active_tasks() {
|
| 158 |
+
return scheduler().n_active_tasks();
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
inline void notify_new_task(const Stream& stream) {
|
| 162 |
+
scheduler().notify_new_task(stream);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
inline void notify_task_completion(const Stream& stream) {
|
| 166 |
+
scheduler().notify_task_completion(stream);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
inline void wait_for_one() {
|
| 170 |
+
scheduler().wait_for_one();
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
} // namespace mlx::core::scheduler
|
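The scheduler declarations in this header pair one `StreamThread` worker with each `Stream` and keep a global count of in-flight tasks. A rough usage sketch of the free functions it exposes; this is illustrative only and not how MLX drives the scheduler internally:

// Example (illustrative, not part of the MLX headers).
#include "mlx/scheduler.h"

using namespace mlx::core;

int main() {
  // Work is queued on a stream's worker thread and counted globally.
  Stream s = scheduler::scheduler().get_default_stream(Device::cpu);

  scheduler::notify_new_task(s);
  scheduler::enqueue(s, [s]() {
    // ... work associated with this stream ...
    scheduler::notify_task_completion(s);
  });

  // wait_for_one() only blocks while more than one task is in flight.
  scheduler::wait_for_one();
  return 0;
}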
lib/python3.11/site-packages/mlx/include/mlx/stream.h
ADDED
@@ -0,0 +1,32 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/device.h"

namespace mlx::core {

struct Stream {
  int index;
  Device device;
  explicit Stream(int index, Device device) : index(index), device(device) {}
};

/** Get the default stream for the given device. */
Stream default_stream(Device d);

/** Make the stream the default for its device. */
void set_default_stream(Stream s);

/** Make a new stream on the given device. */
Stream new_stream(Device d);

inline bool operator==(const Stream& lhs, const Stream& rhs) {
  return lhs.index == rhs.index;
}

inline bool operator!=(const Stream& lhs, const Stream& rhs) {
  return !(lhs == rhs);
}

} // namespace mlx::core
|
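A small sketch of how the `Stream` value type above is typically handled, assuming the free functions are provided by libmlx as declared:

// Example (illustrative, not part of the MLX headers).
#include "mlx/stream.h"

using namespace mlx::core;

int main() {
  Stream cpu_default = default_stream(Device::cpu);
  Stream extra = new_stream(Device::cpu);

  // Streams compare by index only, so a freshly created stream differs
  // from the default one even on the same device.
  bool same = (cpu_default == extra);
  set_default_stream(extra); // route subsequent default work here
  return same ? 1 : 0;
}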
lib/python3.11/site-packages/mlx/include/mlx/transforms.h
ADDED
@@ -0,0 +1,187 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "array.h"

namespace mlx::core {

/** Fuse equivalent arrays to avoid duplicate execution. */
void simplify(const std::vector<array>& outputs);

template <typename... Arrays>
void simplify(Arrays... outputs) {
  simplify(std::vector<array>{std::forward<Arrays>(outputs)...});
}

void eval(const std::vector<array>& outputs, bool retain_graph = false);

template <typename... Arrays>
void eval(Arrays... outputs) {
  eval(std::vector<array>{std::forward<Arrays>(outputs)...}, false);
}

/**
 * Computes the output and vector-Jacobian product (VJP) of a function.
 *
 * Computes the vector-Jacobian product of the vector of cotangents with the
 * Jacobian of the function evaluated at the primals. Returns a pair of
 * vectors of output arrays and VJP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> vjp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& cotangents);

/**
 * Computes the output and vector-Jacobian product (VJP) of a unary function.
 */
std::pair<array, array> vjp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& cotangent);

/**
 * Computes the output and Jacobian-vector product (JVP) of a function.
 *
 * Computes the Jacobian-vector product of the Jacobian of the function
 * evaluated at the primals with the vector of tangents. Returns a pair of
 * vectors of output arrays and JVP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> jvp(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& primals,
    const std::vector<array>& tangents);

/**
 * Computes the output and Jacobian-vector product (JVP) of a unary function.
 */
std::pair<array, array> jvp(
    const std::function<array(const array&)>& fun,
    const array& primal,
    const array& tangent);

// Return type of general value_and_grad: a function which takes an input
// vector of arrays and returns a pair of vectors of arrays, one for the
// values and one for the gradients wrt the first value.
using ValueAndGradFn =
    std::function<std::pair<std::vector<array>, std::vector<array>>(
        const std::vector<array>&)>;
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
    const std::vector<array>&)>;

/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a vector of input arrays.
 **/
ValueAndGradFn value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums);

/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a single input array.
 **/
ValueAndGradFn inline value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return value_and_grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the value and gradient of the unary
 * input function.
 **/
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
    const std::function<array(const array&)>& fun) {
  return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
}

SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [fun, argnums](auto inputs) {
    auto result = value_and_grad(
        [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
        argnums)(inputs);

    return std::make_pair(result.first[0], result.second);
  };
}

SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return value_and_grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the gradient of the input function with
 * respect to a vector of input arrays.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The vector of `argnums` specifies which arguments to compute the
 * gradient with respect to. At least one argument must be specified.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  auto fn = value_and_grad(fun, argnums);
  return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
}

/**
 * Returns a function which computes the gradient of the input function with
 * respect to a single input array.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The optional `argnum` index specifies which argument to compute the
 * gradient with respect to and defaults to 0.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  return grad(fun, std::vector<int>{argnum});
}

/**
 * Returns a function which computes the gradient of the unary input function.
 **/
std::function<array(const array&)> inline grad(
    const std::function<array(const array&)>& fun) {
  auto fn = value_and_grad(fun);
  return [fn](const array& input) { return fn(input).second; };
}

/**
 * Automatically vectorize a unary function over the requested axes.
 */
std::function<array(const array&)> vmap(
    const std::function<array(const array&)>& fun,
    int in_axis = 0,
    int out_axis = 0);

/**
 * Automatically vectorize a binary function over the requested axes.
 */
std::function<array(const array&, const array&)> vmap(
    const std::function<array(const array&, const array&)>& fun,
    int in_axis_a = 0,
    int in_axis_b = 0,
    int out_axis = 0);

/**
 * Automatically vectorize a function over the requested axes.
 *
 * The input function to `vmap` takes as an argument a vector of arrays and
 * returns a vector of arrays. Optionally specify the axes to vectorize over
 * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
 * Returns a vectorized function with the same signature as the input
 * function.
 */
std::function<std::vector<array>(const std::vector<array>&)> vmap(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<int>& in_axes = {},
    const std::vector<int>& out_axes = {});

} // namespace mlx::core
|
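The transform declarations above compose: `grad` is built on `value_and_grad`, which is built on `vjp`. A minimal sketch of the unary overloads; it assumes the umbrella header "mlx/mlx.h" and the array arithmetic operator overloads declared in mlx/ops.h, and the values are illustrative only:

// Example (illustrative, not part of the MLX headers).
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // f(x) = x * x, using the array operator overloads from mlx/ops.h.
  auto f = [](const array& x) { return x * x; };

  array x(2.0f);

  // Unary overloads declared above.
  auto df = grad(f);                   // x -> df/dx
  auto [y, dy] = value_and_grad(f)(x); // (f(x), df/dx)
  array g = df(x);                     // expected to evaluate to 4.0

  // vjp pairs the forward output with a cotangent-weighted gradient.
  auto [out, vjp_val] = vjp(f, x, array(1.0f));

  eval(y, dy, g, out, vjp_val);
  return 0;
}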
lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h
ADDED
@@ -0,0 +1,17 @@
// Copyright © 2023 Apple Inc.

namespace mlx::core::detail {

std::pair<std::vector<array>, std::vector<array>> vmap_trace(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs,
    const std::vector<int>& in_axes);

std::vector<array> vmap_replace(
    const std::vector<array>& inputs,
    const std::vector<array>& s_inputs,
    const std::vector<array>& s_outputs,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes);

} // namespace mlx::core::detail
|
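These two hooks split `vmap` into a tracing step and a replacement step. Purely as a guess at how the public `vmap` in transforms.h could be assembled from them, not the actual implementation:

// Hypothetical composition, for orientation only.
#include <functional>
#include <vector>
#include "mlx/mlx.h"             // brings in array and friends (assumed)
#include "mlx/transforms_impl.h" // the declarations above

std::function<std::vector<mlx::core::array>(const std::vector<mlx::core::array>&)>
vmap_sketch(
    const std::function<
        std::vector<mlx::core::array>(const std::vector<mlx::core::array>&)>& fun,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes) {
  return [=](const std::vector<mlx::core::array>& inputs) {
    // 1) Trace fun on placeholder inputs to get a symbolic graph.
    auto [s_inputs, s_outputs] =
        mlx::core::detail::vmap_trace(fun, inputs, in_axes);
    // 2) Rewrite that graph so each primitive acts over the batched axis.
    return mlx::core::detail::vmap_replace(
        inputs, s_inputs, s_outputs, in_axes, out_axes);
  };
}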
lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h
ADDED
@@ -0,0 +1,187 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <cmath>
|
| 7 |
+
#include <cstdint>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#define __MLX_BFLOAT_NAN__ 0x7FC0
|
| 11 |
+
|
| 12 |
+
namespace mlx::core {
|
| 13 |
+
|
| 14 |
+
namespace {
|
| 15 |
+
union float_bits_bf16 {
|
| 16 |
+
float f;
|
| 17 |
+
uint32_t u;
|
| 18 |
+
};
|
| 19 |
+
} // namespace
|
| 20 |
+
|
| 21 |
+
struct _MLX_BFloat16 {
|
| 22 |
+
uint16_t bits_;
|
| 23 |
+
|
| 24 |
+
// Default constructor
|
| 25 |
+
_MLX_BFloat16() = default;
|
| 26 |
+
|
| 27 |
+
// Default copy constructor
|
| 28 |
+
_MLX_BFloat16(_MLX_BFloat16 const&) = default;
|
| 29 |
+
|
| 30 |
+
// Appease std::vector<bool> for being special
|
| 31 |
+
_MLX_BFloat16& operator=(std::vector<bool>::reference x) {
|
| 32 |
+
bits_ = x;
|
| 33 |
+
return *this;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
_MLX_BFloat16& operator=(const float& x) {
|
| 37 |
+
return (*this = _MLX_BFloat16(x));
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// From float32
|
| 41 |
+
_MLX_BFloat16(const float& x) {
|
| 42 |
+
if (std::isnan(x)) {
|
| 43 |
+
bits_ = __MLX_BFLOAT_NAN__;
|
| 44 |
+
} else {
|
| 45 |
+
// Union
|
| 46 |
+
float_bits_bf16 in;
|
| 47 |
+
|
| 48 |
+
// Take bits
|
| 49 |
+
in.f = x;
|
| 50 |
+
|
| 51 |
+
// Round to nearest even
|
| 52 |
+
in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
|
| 53 |
+
|
| 54 |
+
// Take upper 16 bits
|
| 55 |
+
bits_ = in.u >> 16;
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// To float32
|
| 60 |
+
operator float() const {
|
| 61 |
+
// Union
|
| 62 |
+
float_bits_bf16 out;
|
| 63 |
+
|
| 64 |
+
// Upper 16 bits are the data and lower 16 bits are 0s
|
| 65 |
+
out.u = ((uint32_t)bits_) << 16;
|
| 66 |
+
|
| 67 |
+
return out.f;
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
|
| 72 |
+
inline otype __operator__(atype lhs, btype rhs) { \
|
| 73 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
|
| 77 |
+
inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
|
| 78 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
| 79 |
+
} \
|
| 80 |
+
inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
|
| 81 |
+
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// Operators
|
| 85 |
+
#define bfloat_binop(_op_, _operator_) \
|
| 86 |
+
bfloat_binop_base( \
|
| 87 |
+
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
|
| 88 |
+
bfloat_binop_helper(_op_, _operator_, float, float, float); \
|
| 89 |
+
bfloat_binop_helper(_op_, _operator_, double, double, double); \
|
| 90 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \
|
| 91 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
|
| 92 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
|
| 93 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
|
| 94 |
+
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
|
| 95 |
+
|
| 96 |
+
bfloat_binop(+, operator+);
|
| 97 |
+
bfloat_binop(-, operator-);
|
| 98 |
+
bfloat_binop(*, operator*);
|
| 99 |
+
bfloat_binop(/, operator/);
|
| 100 |
+
|
| 101 |
+
#undef bfloat_binop
|
| 102 |
+
|
| 103 |
+
// Comparison ops
|
| 104 |
+
#define bfloat_compop(__op__, __operator__) \
|
| 105 |
+
bfloat_binop_base( \
|
| 106 |
+
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
|
| 107 |
+
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
|
| 108 |
+
bfloat_binop_helper(__op__, __operator__, bool, double, double); \
|
| 109 |
+
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
| 110 |
+
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
| 111 |
+
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
| 112 |
+
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
| 113 |
+
|
| 114 |
+
bfloat_compop(>, operator>);
|
| 115 |
+
bfloat_compop(<, operator<);
|
| 116 |
+
bfloat_compop(>=, operator>=);
|
| 117 |
+
bfloat_compop(<=, operator<=);
|
| 118 |
+
bfloat_compop(==, operator==);
|
| 119 |
+
bfloat_compop(!=, operator!=);
|
| 120 |
+
|
| 121 |
+
#undef bfloat_compop
|
| 122 |
+
|
| 123 |
+
// Negative
|
| 124 |
+
inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
|
| 125 |
+
return -static_cast<float>(lhs);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
// Inplace ops
|
| 129 |
+
#define bfloat_inplace_op(__op__, __operator__) \
|
| 130 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
|
| 131 |
+
lhs = lhs __op__ rhs; \
|
| 132 |
+
return lhs; \
|
| 133 |
+
} \
|
| 134 |
+
inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \
|
| 135 |
+
lhs = lhs __op__ rhs; \
|
| 136 |
+
return lhs; \
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
bfloat_inplace_op(+, operator+=);
|
| 140 |
+
bfloat_inplace_op(-, operator-=);
|
| 141 |
+
bfloat_inplace_op(*, operator*=);
|
| 142 |
+
bfloat_inplace_op(/, operator/=);
|
| 143 |
+
|
| 144 |
+
#undef bfloat_inplace_op
|
| 145 |
+
|
| 146 |
+
// Bitwise ops
|
| 147 |
+
|
| 148 |
+
#define bfloat_bitop(__op__, __operator__) \
|
| 149 |
+
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
|
| 150 |
+
_MLX_BFloat16 out; \
|
| 151 |
+
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
| 152 |
+
return out; \
|
| 153 |
+
} \
|
| 154 |
+
inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \
|
| 155 |
+
_MLX_BFloat16 out; \
|
| 156 |
+
out.bits_ = lhs.bits_ __op__ rhs; \
|
| 157 |
+
return out; \
|
| 158 |
+
} \
|
| 159 |
+
inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \
|
| 160 |
+
_MLX_BFloat16 out; \
|
| 161 |
+
out.bits_ = lhs __op__ rhs.bits_; \
|
| 162 |
+
return out; \
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
bfloat_bitop(|, operator|);
|
| 166 |
+
bfloat_bitop(&, operator&);
|
| 167 |
+
bfloat_bitop(^, operator^);
|
| 168 |
+
|
| 169 |
+
#undef bfloat_bitop
|
| 170 |
+
|
| 171 |
+
#define bfloat_inplace_bitop(__op__, __operator__) \
|
| 172 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
|
| 173 |
+
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
| 174 |
+
return lhs; \
|
| 175 |
+
} \
|
| 176 |
+
inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \
|
| 177 |
+
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
| 178 |
+
return lhs; \
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
bfloat_inplace_bitop(|, operator|=);
|
| 182 |
+
bfloat_inplace_bitop(&, operator&=);
|
| 183 |
+
bfloat_inplace_bitop(^, operator^=);
|
| 184 |
+
|
| 185 |
+
#undef bfloat_inplace_bitop
|
| 186 |
+
|
| 187 |
+
} // namespace mlx::core
|
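The bfloat16 conversion above keeps the top 16 bits of the float32 pattern after adding a round-to-nearest-even term. A small worked check, assuming this header compiles standalone:

// Example (illustrative, not part of the MLX headers).
#include <cstdio>
#include "mlx/types/bf16.h"

using mlx::core::_MLX_BFloat16;

int main() {
  // 1.0f is exactly representable: the upper 16 bits are 0x3F80.
  _MLX_BFloat16 one(1.0f);
  std::printf("bits(1.0f) = 0x%04X\n", static_cast<unsigned>(one.bits_));

  // 1.00390625f (= 1 + 2^-8) sits exactly between two bfloat16 values,
  // so round-to-nearest-even drops the extra bit and it reads back as 1.0f.
  _MLX_BFloat16 mid(1.00390625f);
  std::printf("round-trip = %f\n", static_cast<float>(mid));

  // Mixed arithmetic promotes through float, as the operator macros above do.
  float sum = one + 0.5f; // 1.5
  std::printf("1.0 + 0.5 = %f\n", sum);
  return 0;
}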
lib/python3.11/site-packages/mlx/include/mlx/types/complex.h
ADDED
@@ -0,0 +1,77 @@
// Copyright © 2023 Apple Inc.

#pragma once
#include <complex>
#include "mlx/types/half_types.h"

namespace mlx::core {

struct complex64_t;

template <typename T>
static constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

struct complex64_t : public std::complex<float> {
  complex64_t(float v, float u) : std::complex<float>(v, u){};
  complex64_t(std::complex<float> v) : std::complex<float>(v){};

  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
  complex64_t(T x) : std::complex<float>(x){};

  operator float() const {
    return real();
  };
};

inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) ||
      (a.real() == b.real() && a.imag() >= b.imag());
}

inline bool operator>(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}

inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  return operator>=(b, a);
}

inline bool operator<(const complex64_t& a, const complex64_t& b) {
  return operator>(b, a);
}

inline complex64_t operator-(const complex64_t& v) {
  return -static_cast<std::complex<float>>(v);
}

// clang-format off
#define complex_binop_helper(_op_, _operator_, itype)            \
  inline complex64_t _operator_(itype x, const complex64_t& y) { \
    return x _op_ static_cast<std::complex<float>>(y);           \
  }                                                               \
  inline complex64_t _operator_(const complex64_t& x, itype y) { \
    return static_cast<std::complex<float>>(x) _op_ y;           \
  }

#define complex_binop(_op_, _operator_)                                       \
  inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
    return static_cast<std::complex<float>>(x)                                \
        _op_ static_cast<std::complex<float>>(y);                             \
  }                                                                           \
  complex_binop_helper(_op_, _operator_, bool)                                \
  complex_binop_helper(_op_, _operator_, uint32_t)                            \
  complex_binop_helper(_op_, _operator_, uint64_t)                            \
  complex_binop_helper(_op_, _operator_, int32_t)                             \
  complex_binop_helper(_op_, _operator_, int64_t)                             \
  complex_binop_helper(_op_, _operator_, float16_t)                           \
  complex_binop_helper(_op_, _operator_, bfloat16_t)                          \
  complex_binop_helper(_op_, _operator_, const std::complex<float>&)          \
  complex_binop_helper(_op_, _operator_, float)
// clang-format on

complex_binop(+, operator+)

} // namespace mlx::core
|
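The comparison operators above order `complex64_t` lexicographically: by real part first, then by imaginary part. A quick check, assuming the header compiles standalone with half_types.h available:

// Example (illustrative, not part of the MLX headers).
#include <cstdio>
#include "mlx/types/complex.h"

using mlx::core::complex64_t;

int main() {
  complex64_t a(1.0f, 5.0f);
  complex64_t b(1.0f, 2.0f);
  complex64_t c(2.0f, 0.0f);

  // Ordered by real part first, then imaginary part.
  std::printf("%d %d %d\n", a > b, c > a, b <= a); // 1 1 1

  // operator+ is generated by the complex_binop macro above.
  complex64_t d = a + c; // (3, 5)
  std::printf("%f %f\n", d.real(), d.imag());
  return 0;
}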
lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h
ADDED
@@ -0,0 +1,234 @@
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <cmath>
|
| 7 |
+
#include <cstdint>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#define __MLX_HALF_NAN__ 0x7D00
|
| 11 |
+
|
| 12 |
+
namespace mlx::core {
|
| 13 |
+
|
| 14 |
+
namespace {
|
| 15 |
+
union float_bits_fp16 {
|
| 16 |
+
float f;
|
| 17 |
+
uint32_t u;
|
| 18 |
+
};
|
| 19 |
+
} // namespace
|
| 20 |
+
|
| 21 |
+
struct _MLX_Float16 {
|
| 22 |
+
uint16_t bits_;
|
| 23 |
+
|
| 24 |
+
// Default constructor
|
| 25 |
+
_MLX_Float16() = default;
|
| 26 |
+
|
| 27 |
+
// Default copy constructor
|
| 28 |
+
_MLX_Float16(_MLX_Float16 const&) = default;
|
| 29 |
+
|
| 30 |
+
// Appease std::vector<bool> for being special
|
| 31 |
+
_MLX_Float16& operator=(std::vector<bool>::reference x) {
|
| 32 |
+
bits_ = x;
|
| 33 |
+
return *this;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
_MLX_Float16& operator=(const float& x) {
|
| 37 |
+
return (*this = _MLX_Float16(x));
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
// From float32
|
| 41 |
+
_MLX_Float16(const float& x) : bits_(0) {
|
| 42 |
+
// Conversion following
|
| 43 |
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
| 44 |
+
|
| 45 |
+
// Union
|
| 46 |
+
float_bits_fp16 in;
|
| 47 |
+
|
| 48 |
+
// Take fp32 bits
|
| 49 |
+
in.f = x;
|
| 50 |
+
|
| 51 |
+
// Find and take sign bit
|
| 52 |
+
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
|
| 53 |
+
uint16_t x_sign_16 = (x_sign_32 >> 16);
|
| 54 |
+
|
| 55 |
+
if (std::isnan(x)) {
|
| 56 |
+
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
|
| 57 |
+
} else {
|
| 58 |
+
// Union
|
| 59 |
+
      float_bits_fp16 inf_scale, zero_scale, magic_bits;

      // Find exponent bits and take the max supported by half
      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
      uint32_t max_expo_32 = uint32_t(0x38800000);
      x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
      x_expo_32 += uint32_t(15) << 23;

      // Handle scaling to inf as needed
      inf_scale.u = uint32_t(0x77800000);
      zero_scale.u = uint32_t(0x08800000);

      // Combine with magic and let addition do rounding
      magic_bits.u = x_expo_32;
      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;

      // Take the lower 5 bits of the exponent
      uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));

      // Collect the lower 12 bits which have the mantissa
      uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);

      // Combine sign, exp and mantissa
      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
    }
  }

  // To float32
  operator float() const {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Union
    float_bits_fp16 out;

    uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
    uint32_t base = (bits_ << 16);
    uint32_t two_base = base + base;

    uint32_t denorm_max = 1u << 27;
    if (two_base < denorm_max) {
      out.u = uint32_t(126) << 23; // magic mask
      out.u |= (two_base >> 17); // Bits from fp16
      out.f -= 0.5f; // magic bias
    } else {
      out.u = uint32_t(0xE0) << 23; // exponent offset
      out.u += (two_base >> 4); // Bits from fp16
      float out_unscaled = out.f; // Store value
      out.u = uint32_t(0x7800000); // exponent scale
      out.f *= out_unscaled;
    }

    // Add sign
    out.u |= x_sign_32;

    return out.f;
  }
};

#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  inline otype __operator__(atype lhs, btype rhs) {                       \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }

#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
  inline otype __operator__(_MLX_Float16 lhs, itype rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }                                                                  \
  inline otype __operator__(itype lhs, _MLX_Float16 rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }

// Operators
#define half_binop(__op__, __operator__)                                      \
  half_binop_base(                                                            \
      __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, float, float, float);               \
  half_binop_helper(__op__, __operator__, double, double, double);            \
  half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float);         \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);

half_binop(+, operator+);
half_binop(-, operator-);
half_binop(*, operator*);
half_binop(/, operator/);

#undef half_binop

// Comparison ops
#define half_compop(__op__, __operator__)                             \
  half_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, bool, float, float);        \
  half_binop_helper(__op__, __operator__, bool, double, double);      \
  half_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint64_t, float);

half_compop(>, operator>);
half_compop(<, operator<);
half_compop(>=, operator>=);
half_compop(<=, operator<=);
half_compop(==, operator==);
half_compop(!=, operator!=);

#undef half_compop

// Negative
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
  return -static_cast<float>(lhs);
}

// Inplace ops
#define half_inplace_op(__op__, __operator__)                              \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
    lhs = lhs __op__ rhs;                                                   \
    return lhs;                                                             \
  }                                                                         \
  inline float& __operator__(float& lhs, _MLX_Float16 rhs) {               \
    lhs = lhs __op__ rhs;                                                   \
    return lhs;                                                             \
  }

half_inplace_op(+, operator+=);
half_inplace_op(-, operator-=);
half_inplace_op(*, operator*=);
half_inplace_op(/, operator/=);

#undef half_inplace_op

// Bitwise ops

#define half_bitop(__op__, __operator__)                                 \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
    _MLX_Float16 out;                                                     \
    out.bits_ = lhs.bits_ __op__ rhs.bits_;                               \
    return out;                                                           \
  }                                                                       \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) {      \
    _MLX_Float16 out;                                                     \
    out.bits_ = lhs.bits_ __op__ rhs;                                     \
    return out;                                                           \
  }                                                                       \
  inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) {      \
    _MLX_Float16 out;                                                     \
    out.bits_ = lhs __op__ rhs.bits_;                                     \
    return out;                                                           \
  }

half_bitop(|, operator|);
half_bitop(&, operator&);
half_bitop(^, operator^);

#undef half_bitop

#define half_inplace_bitop(__op__, __operator__)                           \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                 \
    return lhs;                                                             \
  }                                                                         \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) {     \
    lhs.bits_ = lhs.bits_ __op__ rhs;                                       \
    return lhs;                                                             \
  }

half_inplace_bitop(|, operator|=);
half_inplace_bitop(&, operator&=);
half_inplace_bitop(^, operator^=);

#undef half_inplace_bitop

} // namespace mlx::core
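Aside (not part of the file above): operator float() leans on the "magic multiply" trick from the FP16 library it cites, and the steps are easier to check in isolation. The sketch below re-derives the same decode path against a plain uint16_t, using hypothetical helper names (bits_to_float, half_bits_to_float) and example values chosen here for illustration; it is not code from this package.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reinterpret raw bits as a float (memcpy avoids strict-aliasing issues).
static float bits_to_float(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static float half_bits_to_float(uint16_t h) {
  uint32_t sign = (uint32_t(h) << 16) & 0x80000000u;
  uint32_t two_base = (uint32_t(h) << 16) + (uint32_t(h) << 16); // drops the sign bit

  float out;
  if (two_base < (1u << 27)) {
    // Zero / subnormal: splice the mantissa next to a magic exponent,
    // then subtract the same magic constant to renormalize.
    out = bits_to_float((uint32_t(126) << 23) | (two_base >> 17)) - 0.5f;
  } else {
    // Normal values (and inf/nan): shift into float position, then
    // rescale the exponent with a second constant.
    out = bits_to_float((uint32_t(0xE0) << 23) + (two_base >> 4)) *
        bits_to_float(0x7800000u);
  }

  // Re-attach the sign bit.
  uint32_t u;
  std::memcpy(&u, &out, sizeof(u));
  return bits_to_float(u | sign);
}

int main() {
  std::printf("%g\n", half_bits_to_float(0x3C00)); // 1
  std::printf("%g\n", half_bits_to_float(0xC000)); // -2
  std::printf("%g\n", half_bits_to_float(0x0001)); // ~5.96046e-08 (smallest subnormal)
}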
lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h
ADDED
@@ -0,0 +1,56 @@
// Copyright © 2023 Apple Inc.

#pragma once
#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC

#include <arm_fp16.h>
namespace mlx::core {
typedef __fp16 float16_t;
} // namespace mlx::core

#else

#define ADD_HALF_BINOPS
#include "mlx/types/fp16.h"
namespace mlx::core {
typedef struct _MLX_Float16 float16_t;
} // namespace mlx::core

#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
#ifdef __ARM_FEATURE_BF16

#include <arm_bf16.h>
namespace mlx::core {
typedef __bf16 bfloat16_t;
} // namespace mlx::core

#else

#define ADD_HALF_BINOPS
#include "mlx/types/bf16.h"
namespace mlx::core {
typedef struct _MLX_BFloat16 bfloat16_t;
} // namespace mlx::core

#endif // __ARM_FEATURE_BF16

#ifdef ADD_HALF_BINOPS
namespace mlx::core {

// clang-format off
#define fp16_bf16_binop_helper(__op__, __operator__)                \
  inline float __operator__(float16_t lhs, bfloat16_t rhs) {        \
    return static_cast<float>(lhs) __op__ static_cast<float>(rhs);  \
  }                                                                  \
  inline float __operator__(bfloat16_t lhs, float16_t rhs) {        \
    return static_cast<float>(lhs) __op__ static_cast<float>(rhs);  \
  }

fp16_bf16_binop_helper(+, operator+)
fp16_bf16_binop_helper(-, operator-)
fp16_bf16_binop_helper(*, operator*)
fp16_bf16_binop_helper(/, operator/)
// clang-format on

} // namespace mlx::core
#endif
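Usage note (illustrative, not from the repository): when neither ARM feature macro is defined, ADD_HALF_BINOPS is set and the fp16_bf16_binop_helper overloads above let fp16 and bf16 operands mix, with the result promoted to float. A minimal sketch, assuming such a fallback build and that the package's include directory is on the compiler's include path:

#include <iostream>
#include "mlx/types/half_types.h"

int main() {
  mlx::core::float16_t a(1.5f);   // half-precision storage
  mlx::core::bfloat16_t b(0.25f); // bfloat16 storage
  float c = a + b;                // mixed op promotes to float via the helper
  float d = a * b;
  std::cout << c << " " << d << std::endl; // 1.75 0.375
}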
lib/python3.11/site-packages/mlx/include/mlx/utils.h
ADDED
@@ -0,0 +1,44 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "array.h"
#include "device.h"
#include "dtype.h"
#include "stream.h"

namespace mlx::core {

/** The type from promoting the arrays' types with one another. */
Dtype result_type(const std::vector<array>& arrays);

std::vector<int> broadcast_shapes(
    const std::vector<int>& s1,
    const std::vector<int>& s2);

bool is_same_shape(const std::vector<array>& arrays);

/**
 * Returns the axis normalized to be in the range [0, ndim).
 * Based on numpy's normalize_axis_index. See
 * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
 */
int normalize_axis(int axis, int ndim);

std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
  return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j";
}
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
  return os << static_cast<float>(v);
}
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
  return os << static_cast<float>(v);
}
} // namespace mlx::core
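As a reference point (not the library's implementation): normalize_axis is only declared here, but the numpy normalize_axis_index behaviour it cites maps a possibly negative axis into [0, ndim) and rejects anything out of range. A standalone sketch of that contract, with the name normalize_axis_sketch made up for this illustration:

#include <cstdio>
#include <stdexcept>
#include <string>

// Map an axis in [-ndim, ndim) to [0, ndim); throw on anything outside that range.
int normalize_axis_sketch(int axis, int ndim) {
  if (axis < -ndim || axis >= ndim) {
    throw std::out_of_range(
        "axis " + std::to_string(axis) + " is out of bounds for an array with " +
        std::to_string(ndim) + " dimensions");
  }
  return axis < 0 ? axis + ndim : axis;
}

int main() {
  std::printf("%d\n", normalize_axis_sketch(-1, 3)); // 2
  std::printf("%d\n", normalize_axis_sketch(1, 3));  // 1
}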
lib/python3.11/site-packages/mlx/lib/libmlx.dylib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8abefe46a1f39c92b28464814f05a730fa9899b17757703403c6ef362f06ac93
size 12420704
lib/python3.11/site-packages/mlx/lib/mlx.metallib
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2eedf41000ed270283da11d889bb101aa4c88c6f8f0ec68fe6b040a5be424501
size 59495531
lib/python3.11/site-packages/mlx/nn/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright © 2023 Apple Inc.

from mlx.nn import losses
from mlx.nn.layers import *
from mlx.nn.utils import value_and_grad
lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (381 Bytes)

lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc
ADDED
Binary file (15.3 kB)

lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (1.78 kB)

lib/python3.11/site-packages/mlx/nn/layers/__init__.py
ADDED
@@ -0,0 +1,63 @@
# Copyright © 2023 Apple Inc.

from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    SELU,
    Hardswish,
    LeakyReLU,
    LogSigmoid,
    LogSoftmax,
    Mish,
    PReLU,
    ReLU,
    ReLU6,
    SiLU,
    Softmax,
    Softplus,
    Softsign,
    Step,
    Tanh,
    celu,
    elu,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    hardswish,
    leaky_relu,
    log_sigmoid,
    log_softmax,
    mish,
    prelu,
    relu,
    relu6,
    selu,
    silu,
    softmax,
    softplus,
    softsign,
    step,
    tanh,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
from mlx.nn.layers.convolution import Conv1d, Conv2d
from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d
from mlx.nn.layers.embedding import Embedding
from mlx.nn.layers.linear import Bilinear, Identity, Linear
from mlx.nn.layers.normalization import (
    BatchNorm,
    GroupNorm,
    InstanceNorm,
    LayerNorm,
    RMSNorm,
)
from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.transformer import (
    MultiHeadAttention,
    Transformer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.52 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc
ADDED
Binary file (20.9 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc
ADDED
Binary file (28.1 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc
ADDED
Binary file (1.47 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc
ADDED
Binary file (6.36 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc
ADDED
Binary file (6.71 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc
ADDED
Binary file (2.07 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc
ADDED
Binary file (6.92 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc
ADDED
Binary file (17.7 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc
ADDED
Binary file (10.6 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc
ADDED
Binary file (6.34 kB)

lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc
ADDED
Binary file (18 kB)