Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +9 -0
- gtm/bin/convert-caffe2-to-onnx +8 -0
- gtm/bin/convert-onnx-to-caffe2 +8 -0
- gtm/bin/isympy +8 -0
- gtm/bin/torchrun +8 -0
- gtm/lib/python3.12/site-packages/__pycache__/isympy.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_C.cpython-312-darwin.so +0 -0
- gtm/lib/python3.12/site-packages/functorch/__init__.py +38 -0
- gtm/lib/python3.12/site-packages/functorch/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/__init__.py +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__init__.py +8 -0
- gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__init__.py +7 -0
- gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__init__.py +4 -0
- gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/_src/vmap/__init__.py +16 -0
- gtm/lib/python3.12/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/compile/__init__.py +31 -0
- gtm/lib/python3.12/site-packages/functorch/compile/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__init__.py +179 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/dim.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/magic_trace.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/op_properties.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/reference.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/tree_map.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/wrap_type.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/dim/batch_tensor.py +25 -0
- gtm/lib/python3.12/site-packages/functorch/dim/delayed_mul_tensor.py +77 -0
- gtm/lib/python3.12/site-packages/functorch/dim/dim.py +110 -0
- gtm/lib/python3.12/site-packages/functorch/dim/magic_trace.py +42 -0
- gtm/lib/python3.12/site-packages/functorch/dim/op_properties.py +311 -0
- gtm/lib/python3.12/site-packages/functorch/dim/reference.py +645 -0
- gtm/lib/python3.12/site-packages/functorch/dim/tree_map.py +14 -0
- gtm/lib/python3.12/site-packages/functorch/dim/wrap_type.py +71 -0
- gtm/lib/python3.12/site-packages/functorch/einops/__init__.py +3 -0
- gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/_parsing.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/rearrange.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/einops/_parsing.py +302 -0
- gtm/lib/python3.12/site-packages/functorch/einops/rearrange.py +207 -0
- gtm/lib/python3.12/site-packages/functorch/experimental/__init__.py +6 -0
- gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/__init__.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/control_flow.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/ops.cpython-312.pyc +0 -0
- gtm/lib/python3.12/site-packages/functorch/experimental/control_flow.py +8 -0
.gitattributes
CHANGED
@@ -55,3 +55,12 @@ gtm/lib/python3.12/site-packages/pandas/_libs/join.cpython-312-darwin.so filter=
|
|
55 |
gtm/lib/python3.12/site-packages/pandas/_libs/tslibs/offsets.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
56 |
gtm/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
57 |
gtm/lib/python3.12/site-packages/safetensors/_safetensors_rust.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
gtm/lib/python3.12/site-packages/pandas/_libs/tslibs/offsets.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
56 |
gtm/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
57 |
gtm/lib/python3.12/site-packages/safetensors/_safetensors_rust.cpython-312-darwin.so filter=lfs diff=lfs merge=lfs -text
|
58 |
+
gtm/lib/python3.12/site-packages/sympy/polys/benchmarks/__pycache__/bench_solvers.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
59 |
+
gtm/lib/python3.12/site-packages/torch/.dylibs/libiomp5.dylib filter=lfs diff=lfs merge=lfs -text
|
60 |
+
gtm/lib/python3.12/site-packages/torch/bin/protoc filter=lfs diff=lfs merge=lfs -text
|
61 |
+
gtm/lib/python3.12/site-packages/torch/bin/protoc-3.13.0.0 filter=lfs diff=lfs merge=lfs -text
|
62 |
+
gtm/lib/python3.12/site-packages/torch/lib/libiomp5.dylib filter=lfs diff=lfs merge=lfs -text
|
63 |
+
gtm/lib/python3.12/site-packages/torch/lib/libtorch_cpu.dylib filter=lfs diff=lfs merge=lfs -text
|
64 |
+
gtm/lib/python3.12/site-packages/torch/lib/libtorch_python.dylib filter=lfs diff=lfs merge=lfs -text
|
65 |
+
gtm/lib/python3.12/site-packages/torchvision/.dylibs/libc++.1.0.dylib filter=lfs diff=lfs merge=lfs -text
|
66 |
+
gtm/lib/python3.12/site-packages/torchvision/_C.so filter=lfs diff=lfs merge=lfs -text
|
gtm/bin/convert-caffe2-to-onnx
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from caffe2.python.onnx.bin.conversion import caffe2_to_onnx
|
6 |
+
if __name__ == '__main__':
|
7 |
+
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
8 |
+
sys.exit(caffe2_to_onnx())
|
gtm/bin/convert-onnx-to-caffe2
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from caffe2.python.onnx.bin.conversion import onnx_to_caffe2
|
6 |
+
if __name__ == '__main__':
|
7 |
+
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
8 |
+
sys.exit(onnx_to_caffe2())
|
gtm/bin/isympy
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from isympy import main
|
6 |
+
if __name__ == '__main__':
|
7 |
+
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
8 |
+
sys.exit(main())
|
gtm/bin/torchrun
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/Users/gorgigeorgievski/code/ai/gtmio/gtm/bin/python3.12
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
from torch.distributed.run import main
|
6 |
+
if __name__ == '__main__':
|
7 |
+
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
8 |
+
sys.exit(main())
|
gtm/lib/python3.12/site-packages/__pycache__/isympy.cpython-312.pyc
ADDED
Binary file (11 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_C.cpython-312-darwin.so
ADDED
Binary file (150 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/__init__.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from torch._functorch.deprecated import (
|
9 |
+
combine_state_for_ensemble,
|
10 |
+
functionalize,
|
11 |
+
grad,
|
12 |
+
grad_and_value,
|
13 |
+
hessian,
|
14 |
+
jacfwd,
|
15 |
+
jacrev,
|
16 |
+
jvp,
|
17 |
+
make_functional,
|
18 |
+
make_functional_with_buffers,
|
19 |
+
vjp,
|
20 |
+
vmap,
|
21 |
+
)
|
22 |
+
|
23 |
+
# utilities. Maybe these should go in their own namespace in the future?
|
24 |
+
from torch._functorch.make_functional import (
|
25 |
+
FunctionalModule,
|
26 |
+
FunctionalModuleWithBuffers,
|
27 |
+
)
|
28 |
+
|
29 |
+
# Top-level APIs. Please think carefully before adding something to the
|
30 |
+
# top-level namespace:
|
31 |
+
# - private helper functions should go into torch._functorch
|
32 |
+
# - very experimental things should go into functorch.experimental
|
33 |
+
# - compilation related things should go into functorch.compile
|
34 |
+
|
35 |
+
# Was never documented
|
36 |
+
from torch._functorch.python_key import make_fx
|
37 |
+
|
38 |
+
__version__ = torch.__version__
|
gtm/lib/python3.12/site-packages/functorch/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (755 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_src/__init__.py
ADDED
File without changes
|
gtm/lib/python3.12/site-packages/functorch/_src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (194 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file has moved to under torch/_functorch. It is not public API.
|
2 |
+
# If you are not a PyTorch developer and you are relying on the following
|
3 |
+
# imports, please file an issue.
|
4 |
+
from torch._functorch.aot_autograd import (
|
5 |
+
aot_autograd_decompositions,
|
6 |
+
KNOWN_TYPES,
|
7 |
+
PytreeThunk,
|
8 |
+
)
|
gtm/lib/python3.12/site-packages/functorch/_src/aot_autograd/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (342 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file has moved to under torch/_functorch. It is not public API.
|
2 |
+
# If you are not a PyTorch developer and you are relying on the following
|
3 |
+
# imports, please file an issue.
|
4 |
+
from torch._functorch.eager_transforms import (
|
5 |
+
_assert_wrapped_functional,
|
6 |
+
_unwrap_functional_tensor,
|
7 |
+
)
|
gtm/lib/python3.12/site-packages/functorch/_src/eager_transforms/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (341 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file has moved to under torch/_functorch. It is not public API.
|
2 |
+
# If you are not a PyTorch developer and you are relying on the following
|
3 |
+
# imports, please file an issue.
|
4 |
+
from torch._functorch.make_functional import _swap_state
|
gtm/lib/python3.12/site-packages/functorch/_src/make_functional/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (283 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/_src/vmap/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file has moved to under torch/_functorch. It is not public API.
|
2 |
+
# If you are not a PyTorch developer and you are relying on the following
|
3 |
+
# imports, please file an issue.
|
4 |
+
from torch._functorch.vmap import (
|
5 |
+
_add_batch_dim,
|
6 |
+
_broadcast_to_and_flatten,
|
7 |
+
_create_batched_inputs,
|
8 |
+
_get_name,
|
9 |
+
_process_batched_inputs,
|
10 |
+
_remove_batch_dim,
|
11 |
+
_unwrap_batched,
|
12 |
+
_validate_and_get_batch_size,
|
13 |
+
Tensor,
|
14 |
+
tree_flatten,
|
15 |
+
tree_unflatten,
|
16 |
+
)
|
gtm/lib/python3.12/site-packages/functorch/_src/vmap/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (560 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/compile/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch._functorch import config
|
2 |
+
from torch._functorch.aot_autograd import (
|
3 |
+
aot_function,
|
4 |
+
aot_module,
|
5 |
+
aot_module_simplified,
|
6 |
+
compiled_function,
|
7 |
+
compiled_module,
|
8 |
+
get_aot_compilation_context,
|
9 |
+
get_aot_graph_name,
|
10 |
+
get_graph_being_compiled,
|
11 |
+
make_boxed_compiler,
|
12 |
+
make_boxed_func,
|
13 |
+
)
|
14 |
+
from torch._functorch.compilers import (
|
15 |
+
debug_compile,
|
16 |
+
default_decompositions,
|
17 |
+
draw_graph_compile,
|
18 |
+
memory_efficient_fusion,
|
19 |
+
nnc_jit,
|
20 |
+
nop,
|
21 |
+
print_compile,
|
22 |
+
ts_compile,
|
23 |
+
)
|
24 |
+
from torch._functorch.fx_minifier import minifier
|
25 |
+
from torch._functorch.partitioners import (
|
26 |
+
default_partition,
|
27 |
+
draw_graph,
|
28 |
+
draw_joint_graph,
|
29 |
+
min_cut_rematerialization_partition,
|
30 |
+
)
|
31 |
+
from torch._functorch.python_key import pythonkey_decompose
|
gtm/lib/python3.12/site-packages/functorch/compile/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.15 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__init__.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dis
|
2 |
+
import inspect
|
3 |
+
from typing import Sequence, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
import functorch._C
|
8 |
+
from functorch._C import dim as _C
|
9 |
+
from .tree_map import tree_flatten, tree_map
|
10 |
+
from .wrap_type import wrap_type
|
11 |
+
|
12 |
+
_C._patch_tensor_class()
|
13 |
+
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
|
14 |
+
|
15 |
+
|
16 |
+
class DimensionMismatchError(Exception):
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
class DimensionBindError(Exception):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
from . import op_properties
|
25 |
+
|
26 |
+
# use dict to avoid writing C++ bindings for set
|
27 |
+
pointwise = {t: True for t in op_properties.pointwise}
|
28 |
+
|
29 |
+
use_c = True
|
30 |
+
if not use_c:
|
31 |
+
from . import reference
|
32 |
+
|
33 |
+
|
34 |
+
class _Tensor:
|
35 |
+
# fast path around slow wrapping/unwrapping logic for simply queries used
|
36 |
+
# by the implementation...
|
37 |
+
|
38 |
+
@property
|
39 |
+
def dims(self):
|
40 |
+
return tuple(d for d in self._levels if isinstance(d, Dim))
|
41 |
+
|
42 |
+
def dim(self):
|
43 |
+
return self.ndim
|
44 |
+
|
45 |
+
if use_c:
|
46 |
+
__torch_function__ = classmethod(_C.__torch_function__)
|
47 |
+
expand = _C._instancemethod(_C.expand)
|
48 |
+
else:
|
49 |
+
__torch_function__ = reference.__torch_function__
|
50 |
+
expand = reference.expand
|
51 |
+
|
52 |
+
index = _C._instancemethod(_C.index)
|
53 |
+
|
54 |
+
def __repr__(self):
|
55 |
+
tensor, levels, ndim = self._tensor, self._levels, self.ndim
|
56 |
+
return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
|
57 |
+
|
58 |
+
|
59 |
+
TensorLike = (_Tensor, torch.Tensor)
|
60 |
+
|
61 |
+
|
62 |
+
class Dim(_C.Dim, _Tensor):
|
63 |
+
# note that _C.Dim comes before tensor because we want the Dim API for things like size to take precendence.
|
64 |
+
# Tensor defines format, but we want to print Dims with special formatting
|
65 |
+
__format__ = object.__format__
|
66 |
+
|
67 |
+
|
68 |
+
class Tensor(_Tensor, _C.Tensor):
|
69 |
+
if not use_c:
|
70 |
+
from_batched = staticmethod(_C.Tensor_from_batched)
|
71 |
+
from_positional = staticmethod(_C.Tensor_from_positional)
|
72 |
+
sum = _C._instancemethod(_C.Tensor_sum)
|
73 |
+
|
74 |
+
|
75 |
+
def cat(tensors, dim, new_dim):
|
76 |
+
n = dims()
|
77 |
+
return stack(tensors, n, dim).index([n, dim], new_dim)
|
78 |
+
|
79 |
+
|
80 |
+
if use_c:
|
81 |
+
_wrap = _C._wrap
|
82 |
+
|
83 |
+
def _def(name, *args, **kwargs):
|
84 |
+
orig = getattr(torch.Tensor, name)
|
85 |
+
setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
|
86 |
+
|
87 |
+
t__getitem__ = _C._instancemethod(_C.__getitem__)
|
88 |
+
stack = _C.stack
|
89 |
+
split = _C._instancemethod(_C.split)
|
90 |
+
else:
|
91 |
+
_wrap, _def = reference._wrap, reference._def
|
92 |
+
t__getitem__ = reference.t__getitem__
|
93 |
+
stack = reference.stack
|
94 |
+
split = reference.split
|
95 |
+
|
96 |
+
# note: there is no python reference
|
97 |
+
t__setitem__ = _C._instancemethod(_C.__setitem__)
|
98 |
+
# this is patched in the C API because otherwise torch.Tensor will
|
99 |
+
# no longer be considered a sequence and things will break
|
100 |
+
# torch.Tensor.__getitem__ = t__getitem__
|
101 |
+
|
102 |
+
_Tensor.__getitem__ = t__getitem__
|
103 |
+
# torch.Tensor.__setitem__ = t__setitem__
|
104 |
+
_Tensor.__setitem__ = t__setitem__
|
105 |
+
|
106 |
+
torch.Tensor.split = split
|
107 |
+
_Tensor.split = split
|
108 |
+
torch.Tensor.expand = _C._instancemethod(_C.expand)
|
109 |
+
torch.Tensor.index = _C._instancemethod(_C.index)
|
110 |
+
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
|
111 |
+
del _Tensor.ndim
|
112 |
+
|
113 |
+
if use_c:
|
114 |
+
_Tensor.order = _C._instancemethod(_C.order)
|
115 |
+
else:
|
116 |
+
_Tensor.order = reference.positional
|
117 |
+
|
118 |
+
_def("mean")
|
119 |
+
_def("sum")
|
120 |
+
_def("all")
|
121 |
+
_def("amax")
|
122 |
+
_def("amin")
|
123 |
+
_def("aminmax")
|
124 |
+
_def("any")
|
125 |
+
_def("count_nonzero")
|
126 |
+
_def("logsumexp")
|
127 |
+
_def("nanmean")
|
128 |
+
_def("nansum")
|
129 |
+
_def("prod")
|
130 |
+
_def("std", keepdim_offset=2)
|
131 |
+
_def("var", keepdim_offset=2)
|
132 |
+
_def("max", single_dim=True)
|
133 |
+
_def("min", single_dim=True)
|
134 |
+
_def("argmax", single_dim=True)
|
135 |
+
_def("argmin", single_dim=True)
|
136 |
+
_def("kthvalue", single_dim=True)
|
137 |
+
_def("median", single_dim=True)
|
138 |
+
_def("nanmedian", single_dim=True)
|
139 |
+
_def("mode", single_dim=True)
|
140 |
+
_def("sort", reduce=False)
|
141 |
+
_def("argsort", reduce=False)
|
142 |
+
_def("unbind", single_dim=True)
|
143 |
+
_def("chunk", dim_offset=1, reduce=False)
|
144 |
+
_def("cummax", single_dim=True, reduce=False)
|
145 |
+
_def("cummin", single_dim=True, reduce=False)
|
146 |
+
_def("cumprod", single_dim=True, reduce=False)
|
147 |
+
_def("cumprod_", single_dim=True, reduce=False)
|
148 |
+
_def("cumsum", single_dim=True, reduce=False)
|
149 |
+
_def("cumsum_", single_dim=True, reduce=False)
|
150 |
+
_def("logcumsumexp", single_dim=True, reduce=False)
|
151 |
+
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
|
152 |
+
_def("softmax", single_dim=True, reduce=False)
|
153 |
+
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
|
154 |
+
|
155 |
+
# stuff to handle in the future, because they require special
|
156 |
+
# binding logic for dims
|
157 |
+
# cross
|
158 |
+
# diag_embed
|
159 |
+
# diagonal
|
160 |
+
# diagonal_scatter
|
161 |
+
# diff
|
162 |
+
# nanquantile
|
163 |
+
# quantile
|
164 |
+
# roll
|
165 |
+
# rot90
|
166 |
+
# topk (new dimes on output)
|
167 |
+
# should these all be subsumed by inplace indexing?
|
168 |
+
# index_add_
|
169 |
+
# index_add
|
170 |
+
# index_copy
|
171 |
+
# index_copy_
|
172 |
+
# index_fill
|
173 |
+
# index_fill_
|
174 |
+
# index_select
|
175 |
+
# scatter
|
176 |
+
# scatter_
|
177 |
+
# scatter_add
|
178 |
+
# scatter_add_
|
179 |
+
# scatter_reduce
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (7.16 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/batch_tensor.cpython-312.pyc
ADDED
Binary file (1.13 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/delayed_mul_tensor.cpython-312.pyc
ADDED
Binary file (5.27 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/dim.cpython-312.pyc
ADDED
Binary file (6.18 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/magic_trace.cpython-312.pyc
ADDED
Binary file (2.26 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/op_properties.cpython-312.pyc
ADDED
Binary file (17 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/reference.cpython-312.pyc
ADDED
Binary file (27.8 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/tree_map.cpython-312.pyc
ADDED
Binary file (695 Bytes). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/__pycache__/wrap_type.cpython-312.pyc
ADDED
Binary file (2.15 kB). View file
|
|
gtm/lib/python3.12/site-packages/functorch/dim/batch_tensor.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
from contextlib import contextmanager
|
7 |
+
|
8 |
+
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
|
9 |
+
|
10 |
+
_enabled = False
|
11 |
+
|
12 |
+
|
13 |
+
@contextmanager
|
14 |
+
def _enable_layers(dims):
|
15 |
+
global _enabled
|
16 |
+
assert not _enabled
|
17 |
+
input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
|
18 |
+
n = len(input)
|
19 |
+
try:
|
20 |
+
_vmap_add_layers(input)
|
21 |
+
_enabled = True
|
22 |
+
yield
|
23 |
+
finally:
|
24 |
+
_enabled = False
|
25 |
+
_vmap_remove_layers(n)
|
gtm/lib/python3.12/site-packages/functorch/dim/delayed_mul_tensor.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from . import _Tensor, Tensor
|
9 |
+
from .reference import _dims, _enable_layers, llist, ltuple
|
10 |
+
|
11 |
+
|
12 |
+
class DelayedMulTensor(_Tensor):
|
13 |
+
def __init__(self, lhs, rhs):
|
14 |
+
self._lhs, self._rhs = lhs, rhs
|
15 |
+
self._data = None
|
16 |
+
self._levels_data = None
|
17 |
+
self._has_device = lhs._has_device or rhs._has_device
|
18 |
+
self._batchtensor_data = None
|
19 |
+
self._tensor_data = None
|
20 |
+
|
21 |
+
@property
|
22 |
+
def _levels(self):
|
23 |
+
if self._levels_data is None:
|
24 |
+
levels = llist(self._lhs._levels)
|
25 |
+
for l in self._rhs._levels:
|
26 |
+
if l not in levels:
|
27 |
+
levels.append(l)
|
28 |
+
self._levels_data = ltuple(levels)
|
29 |
+
return self._levels_data
|
30 |
+
|
31 |
+
@property
|
32 |
+
def _batchtensor(self):
|
33 |
+
if self._batchtensor_data is None:
|
34 |
+
with _enable_layers(self._levels):
|
35 |
+
print("bt multiply fallback")
|
36 |
+
self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
|
37 |
+
return self._batchtensor_data
|
38 |
+
|
39 |
+
@property
|
40 |
+
def _tensor(self):
|
41 |
+
if self._tensor_data is None:
|
42 |
+
self._tensor_data = Tensor.from_batched(
|
43 |
+
self._batchtensor, self._has_device
|
44 |
+
)._tensor
|
45 |
+
return self._tensor_data
|
46 |
+
|
47 |
+
@property
|
48 |
+
def ndim(self):
|
49 |
+
return self._batchtensor.ndim
|
50 |
+
|
51 |
+
@property
|
52 |
+
def dims(self):
|
53 |
+
return ltuple(super().dims)
|
54 |
+
|
55 |
+
def sum(self, dim):
|
56 |
+
dims = _dims(dim, 0, False, False)
|
57 |
+
n = ord("a")
|
58 |
+
all_levels = self._levels
|
59 |
+
|
60 |
+
def to_char(d):
|
61 |
+
return chr(n + all_levels.index(d))
|
62 |
+
|
63 |
+
plhs, levelslhs = self._lhs._tensor, self._lhs._levels
|
64 |
+
prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
|
65 |
+
new_dims = tuple(d for d in self.dims if d not in dims)
|
66 |
+
new_levels = [l for l in self._levels if l not in dims]
|
67 |
+
fmt = "".join(
|
68 |
+
[
|
69 |
+
*(to_char(d) for d in levelslhs),
|
70 |
+
",",
|
71 |
+
*(to_char(d) for d in levelsrhs),
|
72 |
+
"->",
|
73 |
+
*(to_char(d) for d in new_levels),
|
74 |
+
]
|
75 |
+
)
|
76 |
+
result_data = torch.einsum(fmt, (plhs, prhs))
|
77 |
+
return Tensor.from_positional(result_data, new_levels, True)
|
gtm/lib/python3.12/site-packages/functorch/dim/dim.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
_vmap_levels = []
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass
|
10 |
+
class LevelInfo:
|
11 |
+
level: int
|
12 |
+
alive: bool = True
|
13 |
+
|
14 |
+
|
15 |
+
class Dim:
|
16 |
+
def __init__(self, name: str, size: Union[None, int] = None):
|
17 |
+
self.name = name
|
18 |
+
self._size = None
|
19 |
+
self._vmap_level = None
|
20 |
+
if size is not None:
|
21 |
+
self.size = size
|
22 |
+
|
23 |
+
def __del__(self):
|
24 |
+
if self._vmap_level is not None:
|
25 |
+
_vmap_active_levels[self._vmap_stack].alive = False
|
26 |
+
while (
|
27 |
+
not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level
|
28 |
+
):
|
29 |
+
_vmap_decrement_nesting()
|
30 |
+
_vmap_levels.pop()
|
31 |
+
|
32 |
+
@property
|
33 |
+
def size(self):
|
34 |
+
assert self.is_bound
|
35 |
+
return self._size
|
36 |
+
|
37 |
+
@size.setter
|
38 |
+
def size(self, size: int):
|
39 |
+
if self._size is None:
|
40 |
+
self._size = size
|
41 |
+
self._vmap_level = _vmap_increment_nesting(size, "same")
|
42 |
+
self._vmap_stack = len(_vmap_levels)
|
43 |
+
_vmap_levels.append(LevelInfo(self._vmap_level))
|
44 |
+
|
45 |
+
elif self._size != size:
|
46 |
+
raise DimensionBindError(
|
47 |
+
f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
|
48 |
+
)
|
49 |
+
|
50 |
+
@property
|
51 |
+
def is_bound(self):
|
52 |
+
return self._size is not None
|
53 |
+
|
54 |
+
def __repr__(self):
|
55 |
+
return self.name
|
56 |
+
|
57 |
+
|
58 |
+
def extract_name(inst):
|
59 |
+
assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
|
60 |
+
return inst.argval
|
61 |
+
|
62 |
+
|
63 |
+
_cache = {}
|
64 |
+
|
65 |
+
|
66 |
+
def dims(lists=0):
|
67 |
+
frame = inspect.currentframe()
|
68 |
+
assert frame is not None
|
69 |
+
calling_frame = frame.f_back
|
70 |
+
assert calling_frame is not None
|
71 |
+
code, lasti = calling_frame.f_code, calling_frame.f_lasti
|
72 |
+
key = (code, lasti)
|
73 |
+
if key not in _cache:
|
74 |
+
first = lasti // 2 + 1
|
75 |
+
instructions = list(dis.get_instructions(calling_frame.f_code))
|
76 |
+
unpack = instructions[first]
|
77 |
+
|
78 |
+
if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
|
79 |
+
# just a single dim, not a list
|
80 |
+
name = unpack.argval
|
81 |
+
ctor = Dim if lists == 0 else DimList
|
82 |
+
_cache[key] = lambda: ctor(name=name)
|
83 |
+
else:
|
84 |
+
assert unpack.opname == "UNPACK_SEQUENCE"
|
85 |
+
ndims = unpack.argval
|
86 |
+
names = tuple(
|
87 |
+
extract_name(instructions[first + 1 + i]) for i in range(ndims)
|
88 |
+
)
|
89 |
+
first_list = len(names) - lists
|
90 |
+
_cache[key] = lambda: tuple(
|
91 |
+
Dim(n) if i < first_list else DimList(name=n)
|
92 |
+
for i, n in enumerate(names)
|
93 |
+
)
|
94 |
+
return _cache[key]()
|
95 |
+
|
96 |
+
|
97 |
+
def _dim_set(positional, arg):
|
98 |
+
def convert(a):
|
99 |
+
if isinstance(a, Dim):
|
100 |
+
return a
|
101 |
+
else:
|
102 |
+
assert isinstance(a, int)
|
103 |
+
return positional[a]
|
104 |
+
|
105 |
+
if arg is None:
|
106 |
+
return positional
|
107 |
+
elif not isinstance(arg, (Dim, int)):
|
108 |
+
return tuple(convert(a) for a in arg)
|
109 |
+
else:
|
110 |
+
return (convert(arg),)
|
gtm/lib/python3.12/site-packages/functorch/dim/magic_trace.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import os
|
7 |
+
import signal
|
8 |
+
import subprocess
|
9 |
+
from contextlib import contextmanager
|
10 |
+
|
11 |
+
|
12 |
+
@contextmanager
|
13 |
+
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
14 |
+
pid = os.getpid()
|
15 |
+
if not os.path.exists(magic_trace_cache):
|
16 |
+
print(f"Downloading magic_trace to: {magic_trace_cache}")
|
17 |
+
subprocess.run(
|
18 |
+
[
|
19 |
+
"wget",
|
20 |
+
"-O",
|
21 |
+
magic_trace_cache,
|
22 |
+
"-q",
|
23 |
+
"https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace",
|
24 |
+
]
|
25 |
+
)
|
26 |
+
subprocess.run(["chmod", "+x", magic_trace_cache])
|
27 |
+
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
|
28 |
+
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
|
29 |
+
while True:
|
30 |
+
x = p.stderr.readline()
|
31 |
+
print(x)
|
32 |
+
if "Attached" in x:
|
33 |
+
break
|
34 |
+
try:
|
35 |
+
yield
|
36 |
+
finally:
|
37 |
+
p.send_signal(signal.SIGINT)
|
38 |
+
r = p.wait()
|
39 |
+
print(p.stderr.read())
|
40 |
+
p.stderr.close()
|
41 |
+
if r != 0:
|
42 |
+
raise ValueError(f"magic_trace exited abnormally: {r}")
|
gtm/lib/python3.12/site-packages/functorch/dim/op_properties.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import torch
|
7 |
+
|
8 |
+
# pointwise operators can go through a faster pathway
|
9 |
+
|
10 |
+
tensor_magic_methods = ["add", ""]
|
11 |
+
pointwise_magic_methods_with_reverse = (
|
12 |
+
"add",
|
13 |
+
"sub",
|
14 |
+
"mul",
|
15 |
+
"floordiv",
|
16 |
+
"div",
|
17 |
+
"truediv",
|
18 |
+
"mod",
|
19 |
+
"pow",
|
20 |
+
"lshift",
|
21 |
+
"rshift",
|
22 |
+
"and",
|
23 |
+
"or",
|
24 |
+
"xor",
|
25 |
+
)
|
26 |
+
pointwise_magic_methods = (
|
27 |
+
*(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
|
28 |
+
"eq",
|
29 |
+
"gt",
|
30 |
+
"le",
|
31 |
+
"lt",
|
32 |
+
"ge",
|
33 |
+
"gt",
|
34 |
+
"ne",
|
35 |
+
"neg",
|
36 |
+
"pos",
|
37 |
+
"abs",
|
38 |
+
"invert",
|
39 |
+
"iadd",
|
40 |
+
"isub",
|
41 |
+
"imul",
|
42 |
+
"ifloordiv",
|
43 |
+
"idiv",
|
44 |
+
"itruediv",
|
45 |
+
"imod",
|
46 |
+
"ipow",
|
47 |
+
"ilshift",
|
48 |
+
"irshift",
|
49 |
+
"iand",
|
50 |
+
"ior",
|
51 |
+
"ixor",
|
52 |
+
"int",
|
53 |
+
"long",
|
54 |
+
"float",
|
55 |
+
"complex",
|
56 |
+
)
|
57 |
+
|
58 |
+
pointwise_methods = (*(f"__{m}__" for m in pointwise_magic_methods),)
|
59 |
+
|
60 |
+
pointwise = (
|
61 |
+
*(getattr(torch.Tensor, m) for m in pointwise_methods),
|
62 |
+
torch.nn.functional.dropout,
|
63 |
+
torch.where,
|
64 |
+
torch.Tensor.abs,
|
65 |
+
torch.abs,
|
66 |
+
torch.Tensor.acos,
|
67 |
+
torch.acos,
|
68 |
+
torch.Tensor.acosh,
|
69 |
+
torch.acosh,
|
70 |
+
torch.Tensor.add,
|
71 |
+
torch.add,
|
72 |
+
torch.Tensor.addcdiv,
|
73 |
+
torch.addcdiv,
|
74 |
+
torch.Tensor.addcmul,
|
75 |
+
torch.addcmul,
|
76 |
+
torch.Tensor.addr,
|
77 |
+
torch.addr,
|
78 |
+
torch.Tensor.angle,
|
79 |
+
torch.angle,
|
80 |
+
torch.Tensor.asin,
|
81 |
+
torch.asin,
|
82 |
+
torch.Tensor.asinh,
|
83 |
+
torch.asinh,
|
84 |
+
torch.Tensor.atan,
|
85 |
+
torch.atan,
|
86 |
+
torch.Tensor.atan2,
|
87 |
+
torch.atan2,
|
88 |
+
torch.Tensor.atanh,
|
89 |
+
torch.atanh,
|
90 |
+
torch.Tensor.bitwise_and,
|
91 |
+
torch.bitwise_and,
|
92 |
+
torch.Tensor.bitwise_left_shift,
|
93 |
+
torch.bitwise_left_shift,
|
94 |
+
torch.Tensor.bitwise_not,
|
95 |
+
torch.bitwise_not,
|
96 |
+
torch.Tensor.bitwise_or,
|
97 |
+
torch.bitwise_or,
|
98 |
+
torch.Tensor.bitwise_right_shift,
|
99 |
+
torch.bitwise_right_shift,
|
100 |
+
torch.Tensor.bitwise_xor,
|
101 |
+
torch.bitwise_xor,
|
102 |
+
torch.Tensor.ceil,
|
103 |
+
torch.ceil,
|
104 |
+
torch.celu,
|
105 |
+
torch.nn.functional.celu,
|
106 |
+
torch.Tensor.clamp,
|
107 |
+
torch.clamp,
|
108 |
+
torch.Tensor.clamp_max,
|
109 |
+
torch.clamp_max,
|
110 |
+
torch.Tensor.clamp_min,
|
111 |
+
torch.clamp_min,
|
112 |
+
torch.Tensor.copysign,
|
113 |
+
torch.copysign,
|
114 |
+
torch.Tensor.cos,
|
115 |
+
torch.cos,
|
116 |
+
torch.Tensor.cosh,
|
117 |
+
torch.cosh,
|
118 |
+
torch.Tensor.deg2rad,
|
119 |
+
torch.deg2rad,
|
120 |
+
torch.Tensor.digamma,
|
121 |
+
torch.digamma,
|
122 |
+
torch.Tensor.div,
|
123 |
+
torch.div,
|
124 |
+
torch.dropout,
|
125 |
+
torch.nn.functional.dropout,
|
126 |
+
torch.nn.functional.elu,
|
127 |
+
torch.Tensor.eq,
|
128 |
+
torch.eq,
|
129 |
+
torch.Tensor.erf,
|
130 |
+
torch.erf,
|
131 |
+
torch.Tensor.erfc,
|
132 |
+
torch.erfc,
|
133 |
+
torch.Tensor.erfinv,
|
134 |
+
torch.erfinv,
|
135 |
+
torch.Tensor.exp,
|
136 |
+
torch.exp,
|
137 |
+
torch.Tensor.exp2,
|
138 |
+
torch.exp2,
|
139 |
+
torch.Tensor.expm1,
|
140 |
+
torch.expm1,
|
141 |
+
torch.feature_dropout,
|
142 |
+
torch.Tensor.float_power,
|
143 |
+
torch.float_power,
|
144 |
+
torch.Tensor.floor,
|
145 |
+
torch.floor,
|
146 |
+
torch.Tensor.floor_divide,
|
147 |
+
torch.floor_divide,
|
148 |
+
torch.Tensor.fmod,
|
149 |
+
torch.fmod,
|
150 |
+
torch.Tensor.frac,
|
151 |
+
torch.frac,
|
152 |
+
torch.Tensor.frexp,
|
153 |
+
torch.frexp,
|
154 |
+
torch.Tensor.gcd,
|
155 |
+
torch.gcd,
|
156 |
+
torch.Tensor.ge,
|
157 |
+
torch.ge,
|
158 |
+
torch.nn.functional.gelu,
|
159 |
+
torch.nn.functional.glu,
|
160 |
+
torch.Tensor.gt,
|
161 |
+
torch.gt,
|
162 |
+
torch.Tensor.hardshrink,
|
163 |
+
torch.hardshrink,
|
164 |
+
torch.nn.functional.hardshrink,
|
165 |
+
torch.nn.functional.hardsigmoid,
|
166 |
+
torch.nn.functional.hardswish,
|
167 |
+
torch.nn.functional.hardtanh,
|
168 |
+
torch.Tensor.heaviside,
|
169 |
+
torch.heaviside,
|
170 |
+
torch.Tensor.hypot,
|
171 |
+
torch.hypot,
|
172 |
+
torch.Tensor.i0,
|
173 |
+
torch.i0,
|
174 |
+
torch.Tensor.igamma,
|
175 |
+
torch.igamma,
|
176 |
+
torch.Tensor.igammac,
|
177 |
+
torch.igammac,
|
178 |
+
torch.Tensor.isclose,
|
179 |
+
torch.isclose,
|
180 |
+
torch.Tensor.isfinite,
|
181 |
+
torch.isfinite,
|
182 |
+
torch.Tensor.isinf,
|
183 |
+
torch.isinf,
|
184 |
+
torch.Tensor.isnan,
|
185 |
+
torch.isnan,
|
186 |
+
torch.Tensor.isneginf,
|
187 |
+
torch.isneginf,
|
188 |
+
torch.Tensor.isposinf,
|
189 |
+
torch.isposinf,
|
190 |
+
torch.Tensor.isreal,
|
191 |
+
torch.isreal,
|
192 |
+
torch.Tensor.kron,
|
193 |
+
torch.kron,
|
194 |
+
torch.Tensor.lcm,
|
195 |
+
torch.lcm,
|
196 |
+
torch.Tensor.ldexp,
|
197 |
+
torch.ldexp,
|
198 |
+
torch.Tensor.le,
|
199 |
+
torch.le,
|
200 |
+
torch.nn.functional.leaky_relu,
|
201 |
+
torch.Tensor.lerp,
|
202 |
+
torch.lerp,
|
203 |
+
torch.Tensor.lgamma,
|
204 |
+
torch.lgamma,
|
205 |
+
torch.Tensor.log,
|
206 |
+
torch.log,
|
207 |
+
torch.Tensor.log10,
|
208 |
+
torch.log10,
|
209 |
+
torch.Tensor.log1p,
|
210 |
+
torch.log1p,
|
211 |
+
torch.Tensor.log2,
|
212 |
+
torch.log2,
|
213 |
+
torch.nn.functional.logsigmoid,
|
214 |
+
torch.Tensor.logical_and,
|
215 |
+
torch.logical_and,
|
216 |
+
torch.Tensor.logical_not,
|
217 |
+
torch.logical_not,
|
218 |
+
torch.Tensor.logical_or,
|
219 |
+
torch.logical_or,
|
220 |
+
torch.Tensor.logical_xor,
|
221 |
+
torch.logical_xor,
|
222 |
+
torch.Tensor.logit,
|
223 |
+
torch.logit,
|
224 |
+
torch.Tensor.lt,
|
225 |
+
torch.lt,
|
226 |
+
torch.Tensor.maximum,
|
227 |
+
torch.maximum,
|
228 |
+
torch.Tensor.minimum,
|
229 |
+
torch.minimum,
|
230 |
+
torch.nn.functional.mish,
|
231 |
+
torch.Tensor.mvlgamma,
|
232 |
+
torch.mvlgamma,
|
233 |
+
torch.Tensor.nan_to_num,
|
234 |
+
torch.nan_to_num,
|
235 |
+
torch.Tensor.ne,
|
236 |
+
torch.ne,
|
237 |
+
torch.Tensor.neg,
|
238 |
+
torch.neg,
|
239 |
+
torch.Tensor.nextafter,
|
240 |
+
torch.nextafter,
|
241 |
+
torch.Tensor.outer,
|
242 |
+
torch.outer,
|
243 |
+
torch.polar,
|
244 |
+
torch.Tensor.polygamma,
|
245 |
+
torch.polygamma,
|
246 |
+
torch.Tensor.positive,
|
247 |
+
torch.positive,
|
248 |
+
torch.Tensor.pow,
|
249 |
+
torch.pow,
|
250 |
+
torch.Tensor.prelu,
|
251 |
+
torch.prelu,
|
252 |
+
torch.nn.functional.prelu,
|
253 |
+
torch.Tensor.rad2deg,
|
254 |
+
torch.rad2deg,
|
255 |
+
torch.Tensor.reciprocal,
|
256 |
+
torch.reciprocal,
|
257 |
+
torch.Tensor.relu,
|
258 |
+
torch.relu,
|
259 |
+
torch.nn.functional.relu,
|
260 |
+
torch.nn.functional.relu6,
|
261 |
+
torch.Tensor.remainder,
|
262 |
+
torch.remainder,
|
263 |
+
torch.Tensor.round,
|
264 |
+
torch.round,
|
265 |
+
torch.rrelu,
|
266 |
+
torch.nn.functional.rrelu,
|
267 |
+
torch.Tensor.rsqrt,
|
268 |
+
torch.rsqrt,
|
269 |
+
torch.rsub,
|
270 |
+
torch.selu,
|
271 |
+
torch.nn.functional.selu,
|
272 |
+
torch.Tensor.sgn,
|
273 |
+
torch.sgn,
|
274 |
+
torch.Tensor.sigmoid,
|
275 |
+
torch.sigmoid,
|
276 |
+
torch.nn.functional.sigmoid,
|
277 |
+
torch.Tensor.sign,
|
278 |
+
torch.sign,
|
279 |
+
torch.Tensor.signbit,
|
280 |
+
torch.signbit,
|
281 |
+
torch.nn.functional.silu,
|
282 |
+
torch.Tensor.sin,
|
283 |
+
torch.sin,
|
284 |
+
torch.Tensor.sinc,
|
285 |
+
torch.sinc,
|
286 |
+
torch.Tensor.sinh,
|
287 |
+
torch.sinh,
|
288 |
+
torch.nn.functional.softplus,
|
289 |
+
torch.nn.functional.softshrink,
|
290 |
+
torch.Tensor.sqrt,
|
291 |
+
torch.sqrt,
|
292 |
+
torch.Tensor.square,
|
293 |
+
torch.square,
|
294 |
+
torch.Tensor.sub,
|
295 |
+
torch.sub,
|
296 |
+
torch.Tensor.tan,
|
297 |
+
torch.tan,
|
298 |
+
torch.Tensor.tanh,
|
299 |
+
torch.tanh,
|
300 |
+
torch.nn.functional.tanh,
|
301 |
+
torch.threshold,
|
302 |
+
torch.nn.functional.threshold,
|
303 |
+
torch.trapz,
|
304 |
+
torch.Tensor.true_divide,
|
305 |
+
torch.true_divide,
|
306 |
+
torch.Tensor.trunc,
|
307 |
+
torch.trunc,
|
308 |
+
torch.Tensor.xlogy,
|
309 |
+
torch.xlogy,
|
310 |
+
torch.rand_like,
|
311 |
+
)
|
gtm/lib/python3.12/site-packages/functorch/dim/reference.py
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# reference python implementations for C ops
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from functorch._C import dim as _C
|
11 |
+
from . import op_properties
|
12 |
+
from .batch_tensor import _enable_layers
|
13 |
+
from .tree_map import tree_flatten, tree_map
|
14 |
+
|
15 |
+
DimList = _C.DimList
|
16 |
+
import operator
|
17 |
+
from functools import reduce
|
18 |
+
|
19 |
+
|
20 |
+
# use dict to avoid writing C++ bindings for set
|
21 |
+
pointwise = set(op_properties.pointwise)
|
22 |
+
|
23 |
+
|
24 |
+
def prod(x):
|
25 |
+
return reduce(operator.mul, x, 1)
|
26 |
+
|
27 |
+
|
28 |
+
def _wrap_dim(d, N, keepdim):
|
29 |
+
from . import Dim
|
30 |
+
|
31 |
+
if isinstance(d, Dim):
|
32 |
+
assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
|
33 |
+
return d
|
34 |
+
elif d >= 0:
|
35 |
+
return d - N
|
36 |
+
else:
|
37 |
+
return d
|
38 |
+
|
39 |
+
|
40 |
+
def _dims(d, N, keepdim, single_dim):
|
41 |
+
from . import Dim
|
42 |
+
|
43 |
+
if isinstance(d, (Dim, int)):
|
44 |
+
return ltuple((_wrap_dim(d, N, keepdim),))
|
45 |
+
assert not single_dim, f"expected a single dimension or int but found: {d}"
|
46 |
+
return ltuple(_wrap_dim(x, N, keepdim) for x in d)
|
47 |
+
|
48 |
+
|
49 |
+
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
|
50 |
+
from . import DimensionMismatchError
|
51 |
+
|
52 |
+
not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
|
53 |
+
if len(not_bound) == 1:
|
54 |
+
idx, d = not_bound[0]
|
55 |
+
rhs_so_far = prod(r.size for r in rhs if r.is_bound)
|
56 |
+
if lhs_size % rhs_so_far != 0:
|
57 |
+
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
|
58 |
+
raise DimensionMismatchError(
|
59 |
+
f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
|
60 |
+
)
|
61 |
+
new_size = lhs_size // rhs_so_far
|
62 |
+
d.size = new_size
|
63 |
+
elif len(not_bound) > 1:
|
64 |
+
rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
|
65 |
+
raise DimensionMismatchError(
|
66 |
+
f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
|
67 |
+
)
|
68 |
+
else:
|
69 |
+
rhs_size = prod(r.size for r in rhs)
|
70 |
+
if lhs_size != rhs_size:
|
71 |
+
raise DimensionMismatchError(
|
72 |
+
f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def _tensor_levels(inp):
|
77 |
+
from . import _Tensor
|
78 |
+
|
79 |
+
if isinstance(inp, _Tensor):
|
80 |
+
return inp._tensor, llist(inp._levels), inp._has_device
|
81 |
+
else:
|
82 |
+
return inp, llist(range(-inp.ndim, 0)), True
|
83 |
+
|
84 |
+
|
85 |
+
def _match_levels(v, from_levels, to_levels):
|
86 |
+
view = []
|
87 |
+
permute = []
|
88 |
+
requires_view = False
|
89 |
+
size = v.size()
|
90 |
+
for t in to_levels:
|
91 |
+
try:
|
92 |
+
idx = from_levels.index(t)
|
93 |
+
permute.append(idx)
|
94 |
+
view.append(size[idx])
|
95 |
+
except ValueError:
|
96 |
+
view.append(1)
|
97 |
+
requires_view = True
|
98 |
+
if permute != list(range(len(permute))):
|
99 |
+
v = v.permute(*permute)
|
100 |
+
if requires_view:
|
101 |
+
v = v.view(*view)
|
102 |
+
return v
|
103 |
+
|
104 |
+
|
105 |
+
# make a single dimension positional but do not permute it,
|
106 |
+
# used to do multi-tensor operators where the dim being acted on
|
107 |
+
# should not physically move if possible
|
108 |
+
def _positional_no_permute(self, dim, expand_dim=False):
|
109 |
+
from . import Tensor
|
110 |
+
|
111 |
+
ptensor, levels = self._tensor, llist(self._levels)
|
112 |
+
try:
|
113 |
+
idx = levels.index(dim)
|
114 |
+
except ValueError:
|
115 |
+
if not expand_dim:
|
116 |
+
raise
|
117 |
+
idx = 0
|
118 |
+
ptensor = ptensor.expand(dim.size, *ptensor.size())
|
119 |
+
levels.insert(0, 0)
|
120 |
+
idx_batched = 0
|
121 |
+
for i in range(idx):
|
122 |
+
if isinstance(levels[i], int):
|
123 |
+
levels[i] -= 1
|
124 |
+
idx_batched += 1
|
125 |
+
levels[idx] = -idx_batched - 1
|
126 |
+
return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
|
127 |
+
|
128 |
+
|
129 |
+
def seq(a, b):
|
130 |
+
from . import Dim
|
131 |
+
|
132 |
+
if isinstance(a, Dim) != isinstance(b, Dim):
|
133 |
+
return False
|
134 |
+
if isinstance(a, Dim):
|
135 |
+
return a is b
|
136 |
+
else:
|
137 |
+
return a == b
|
138 |
+
|
139 |
+
|
140 |
+
class isin:
|
141 |
+
def __contains__(self, item):
|
142 |
+
for x in self:
|
143 |
+
if seq(item, x):
|
144 |
+
return True
|
145 |
+
return False
|
146 |
+
|
147 |
+
def index(self, item):
|
148 |
+
for i, x in enumerate(self):
|
149 |
+
if seq(item, x):
|
150 |
+
return i
|
151 |
+
raise ValueError
|
152 |
+
|
153 |
+
|
154 |
+
class llist(isin, list):
|
155 |
+
pass
|
156 |
+
|
157 |
+
|
158 |
+
class ltuple(isin, tuple):
|
159 |
+
pass
|
160 |
+
|
161 |
+
|
162 |
+
empty_dict = {}
|
163 |
+
|
164 |
+
|
165 |
+
@classmethod
|
166 |
+
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
|
167 |
+
from . import _Tensor, Tensor, TensorLike
|
168 |
+
from .delayed_mul_tensor import DelayedMulTensor
|
169 |
+
|
170 |
+
if orig is torch.Tensor.__mul__:
|
171 |
+
lhs, rhs = args
|
172 |
+
if (
|
173 |
+
isinstance(lhs, _Tensor)
|
174 |
+
and isinstance(rhs, _Tensor)
|
175 |
+
and lhs.ndim == 0
|
176 |
+
and rhs.ndim == 0
|
177 |
+
):
|
178 |
+
return DelayedMulTensor(lhs, rhs)
|
179 |
+
all_dims = llist()
|
180 |
+
flat_args, unflatten = tree_flatten((args, kwargs))
|
181 |
+
device_holding_tensor = None
|
182 |
+
for f in flat_args:
|
183 |
+
if isinstance(f, _Tensor):
|
184 |
+
if f._has_device:
|
185 |
+
device_holding_tensor = f._batchtensor
|
186 |
+
for d in f.dims:
|
187 |
+
if d not in all_dims:
|
188 |
+
all_dims.append(d)
|
189 |
+
|
190 |
+
def unwrap(t):
|
191 |
+
if isinstance(t, _Tensor):
|
192 |
+
r = t._batchtensor
|
193 |
+
if device_holding_tensor is not None and not t._has_device:
|
194 |
+
r = r.to(device=device_holding_tensor.device)
|
195 |
+
return r
|
196 |
+
return t
|
197 |
+
|
198 |
+
if orig in pointwise:
|
199 |
+
result_levels = llist()
|
200 |
+
arg_levels = llist()
|
201 |
+
to_expand = []
|
202 |
+
for i, f in enumerate(flat_args):
|
203 |
+
if isinstance(f, TensorLike):
|
204 |
+
ptensor, levels, _ = _tensor_levels(f)
|
205 |
+
if (
|
206 |
+
isinstance(f, _Tensor)
|
207 |
+
and not f._has_device
|
208 |
+
and device_holding_tensor is not None
|
209 |
+
):
|
210 |
+
ptensor = ptensor.to(device=device_holding_tensor.device)
|
211 |
+
flat_args[i] = ptensor
|
212 |
+
for l in levels:
|
213 |
+
if l not in result_levels:
|
214 |
+
result_levels.append(l)
|
215 |
+
to_expand.append((i, levels))
|
216 |
+
|
217 |
+
for i, levels in to_expand:
|
218 |
+
flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
|
219 |
+
args, kwargs = unflatten(flat_args)
|
220 |
+
result = orig(*args, **kwargs)
|
221 |
+
|
222 |
+
def wrap(t):
|
223 |
+
if isinstance(t, TensorLike):
|
224 |
+
return Tensor.from_positional(
|
225 |
+
t, result_levels, device_holding_tensor is not None
|
226 |
+
)
|
227 |
+
return t
|
228 |
+
|
229 |
+
return tree_map(wrap, result)
|
230 |
+
else:
|
231 |
+
|
232 |
+
def wrap(t):
|
233 |
+
if isinstance(t, TensorLike):
|
234 |
+
return Tensor.from_batched(t, device_holding_tensor is not None)
|
235 |
+
return t
|
236 |
+
|
237 |
+
with _enable_layers(all_dims):
|
238 |
+
print(f"batch_tensor for {orig}")
|
239 |
+
args, kwargs = unflatten(unwrap(f) for f in flat_args)
|
240 |
+
result = orig(*args, **kwargs)
|
241 |
+
# print("END", orig)
|
242 |
+
return tree_map(wrap, result)
|
243 |
+
|
244 |
+
|
245 |
+
def positional(self, *dims):
|
246 |
+
from . import Dim, Tensor
|
247 |
+
|
248 |
+
ptensor, levels = self._tensor, llist(self._levels)
|
249 |
+
flat_dims = llist()
|
250 |
+
view = []
|
251 |
+
needs_view = False
|
252 |
+
ndim = self.ndim
|
253 |
+
for d in dims:
|
254 |
+
if isinstance(d, DimList):
|
255 |
+
flat_dims.extend(d)
|
256 |
+
view.extend(e.size for e in d)
|
257 |
+
elif isinstance(d, Dim):
|
258 |
+
flat_dims.append(d)
|
259 |
+
view.append(d.size)
|
260 |
+
elif isinstance(d, int):
|
261 |
+
d = _wrap_dim(d, ndim, False)
|
262 |
+
flat_dims.append(d)
|
263 |
+
view.append(ptensor.size(d))
|
264 |
+
else:
|
265 |
+
flat_dims.extend(d)
|
266 |
+
view.append(prod(e.size for e in d))
|
267 |
+
needs_view = True
|
268 |
+
|
269 |
+
permute = list(range(len(levels)))
|
270 |
+
nflat = len(flat_dims)
|
271 |
+
for i, d in enumerate(flat_dims):
|
272 |
+
try:
|
273 |
+
idx = levels.index(d)
|
274 |
+
except ValueError as e:
|
275 |
+
raise DimensionBindError(
|
276 |
+
f"tensor of dimensions {self.dims} does not contain dim {d}"
|
277 |
+
) from e
|
278 |
+
p = permute[idx]
|
279 |
+
del levels[idx]
|
280 |
+
del permute[idx]
|
281 |
+
levels.insert(i, 0)
|
282 |
+
permute.insert(i, p)
|
283 |
+
ptensor = ptensor.permute(*permute)
|
284 |
+
seen = 0
|
285 |
+
for i in range(len(levels) - 1, -1, -1):
|
286 |
+
if isinstance(levels[i], int):
|
287 |
+
seen += 1
|
288 |
+
levels[i] = -seen
|
289 |
+
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result


def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]


_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            dim_indices = dim_indices[
                0
            ]  # so that dims that really only take a single argument work...
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))


no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    def __init__(self):
        self.dims = llist()
        self.count = []

    def record(self, d):
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]


def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to original example if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count to the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)

    # this handles bool indexing, as well as some other simple cases.

    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        and
        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
        not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or isinstance(s, DimList) and not s.is_bound:
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )
    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [
                no_slice
            ] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)
    # flatten the dim lists into the indexing
    for i in reversed(dimlists):
        input[i : i + 1] = input[i]
    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(idx)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well.

    # to index:
    # drop the first class dims from self, they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place a tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # indices to ptensor rather than self which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)


# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))


_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import _Tensor, Dim

    if isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    ):
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    return tuple(
        t.index(dim, d)
        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
    )
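Note (illustrative only, not part of the uploaded file): the functions above — t__getitem__, stack, and split — are the positional glue behind first-class dimensions. A minimal usage sketch, assuming functorch.dim has been imported so that torch.Tensor.__getitem__ and split are patched as this package does, and using the public dims() / order() entry points:

import torch
from functorch.dim import dims

batch, channel = dims(2)             # two first-class dimensions
t = torch.randn(8, 3)
x = t[batch, channel]                # t__getitem__ binds batch.size=8, channel.size=3
y = (x * 2).order(batch, channel)    # back to a positional (8, 3) tensor

# Dim-sized split: unbound Dim objects pick up the sizes of the resulting chunks.
left, right = dims(2)
a, b = t.split((left, right), dim=0)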
gtm/lib/python3.12/site-packages/functorch/dim/tree_map.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functorch._C import dim

tree_flatten = dim.tree_flatten


def tree_map(fn, tree):
    vs, unflatten = tree_flatten(tree)
    return unflatten(fn(v) for v in vs)
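Note (illustrative only, not part of the uploaded file): a small sketch of how this helper can be used, assuming the C-level tree_flatten accepts nested lists/tuples of tensors the way torch.utils._pytree does:

import torch
from functorch.dim.tree_map import tree_map

# Double every tensor leaf in a nested structure; the structure is rebuilt by `unflatten`.
nested = [torch.ones(2), (torch.zeros(3), torch.arange(4))]
doubled = tree_map(lambda t: t * 2, nested)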
gtm/lib/python3.12/site-packages/functorch/dim/wrap_type.py
ADDED
@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from types import (
    BuiltinMethodType,
    FunctionType,
    GetSetDescriptorType,
    MethodDescriptorType,
    WrapperDescriptorType,
)

from functorch._C import dim as _C

_wrap_method = _C._wrap_method

FUNC_TYPES = (
    FunctionType,
    MethodDescriptorType,
    BuiltinMethodType,
    WrapperDescriptorType,
)
PROPERTY_TYPES = (GetSetDescriptorType, property)


def _py_wrap_method(orig, __torch_function__):
    def impl(*args, **kwargs):
        return __torch_function__(orig, None, args, kwargs)

    return impl


def wrap_type(use_c, to_patch, pattern, __torch_function__):
    if use_c:
        wrap_method = _wrap_method
    else:
        wrap_method = _py_wrap_method

    all = {}
    for t in reversed(pattern.mro()[:-1]):  # skip object
        all.update(t.__dict__)

    def wrap_attr(orig):
        return property(wrap_method(orig.__get__, __torch_function__))

    for name, obj in all.items():
        if name in (
            "__dict__",
            "__new__",
            "__init__",
            "__repr__",
            "__weakref__",
            "__doc__",
            "__module__",
            "__dir__",
        ):
            continue

        # skip things that have been overloaded
        # things that come from object like `__eq__` still need to be patched, however.
        if hasattr(to_patch, name) and getattr(to_patch, name) is not getattr(
            object, name, None
        ):
            continue

        if isinstance(obj, FUNC_TYPES):
            setattr(to_patch, name, wrap_method(obj, __torch_function__))
        elif isinstance(obj, PROPERTY_TYPES):
            setattr(to_patch, name, wrap_attr(obj))
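Note (illustrative only, not part of the uploaded file): a toy sketch of what wrap_type does, using a hypothetical Pattern/Wrapped pair instead of torch.Tensor/_Tensor. It copies the pattern class's methods onto the target class and re-routes each call through a __torch_function__-style interceptor (pure-Python path, use_c=False):

from functorch.dim.wrap_type import wrap_type

class Pattern:
    def double(self):
        return self * 2

class Wrapped(int):
    pass

def intercept(orig, types, args=(), kwargs=None):
    # called as __torch_function__(orig, None, args, kwargs) by _py_wrap_method
    kwargs = kwargs or {}
    print(f"intercepted {orig.__name__}")
    return orig(*args, **kwargs)

wrap_type(False, Wrapped, Pattern, intercept)

x = Wrapped(21)
print(x.double())  # prints "intercepted double", then 42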
gtm/lib/python3.12/site-packages/functorch/einops/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .rearrange import rearrange

__all__ = ["rearrange"]
gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (258 Bytes).

gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/_parsing.cpython-312.pyc
ADDED
Binary file (13 kB).

gtm/lib/python3.12/site-packages/functorch/einops/__pycache__/rearrange.cpython-312.pyc
ADDED
Binary file (9.82 kB).
gtm/lib/python3.12/site-packages/functorch/einops/_parsing.py
ADDED
@@ -0,0 +1,302 @@
"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.

MIT License

Copyright (c) 2018 Alex Rogozhnikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations

import keyword
import warnings
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union

_ellipsis: str = "…"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated


class AnonymousAxis:
    """Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.

    Note: Different instances of this class are not equal to each other, even if they have the same value.
    """

    def __init__(self, value: str) -> None:
        self.value = int(value)
        if self.value < 1:
            raise ValueError(
                f"Anonymous axis should have positive length, not {self.value}"
            )

    def __repr__(self) -> str:
        return f"{self.value}-axis"


class ParsedExpression:
    """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""

    def __init__(
        self,
        expression: str,
        *,
        allow_underscore: bool = False,
        allow_duplicates: bool = False,
    ) -> None:
        """Parse the expression and store relevant metadata.

        Args:
            expression (str): the `einops`-pattern to parse
            allow_underscore (bool): whether to allow axis identifier names to begin with an underscore
            allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None

        def add_axis_name(x: str) -> None:
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise ValueError(
                        f"Indexing expression contains duplicate dimension '{x}'"
                    )
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    else:
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(
                    x, allow_underscore=allow_underscore
                )
                if not (is_number or is_axis_name):
                    raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
                axis_name: Union[str, AnonymousAxis] = (
                    AnonymousAxis(x) if is_number else x
                )
                self.identifiers.add(axis_name)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([axis_name])
                else:
                    bracket_group.append(axis_name)

        current_identifier = None
        for char in expression:
            if char in "() ":
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
    ) -> Tuple[bool, str]:
        """Check if the given axis name is valid, and a message explaining why if not.

        Valid axes names are python identifiers except keywords, and should not start or end with an underscore.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether axis names are allowed to start with an underscore

        Returns:
            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        elif name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        else:
            if keyword.iskeyword(name):
                warnings.warn(
                    f"It is discouraged to use axes names that are keywords: {name}",
                    RuntimeWarning,
                )
            if name in ["axis"]:
                warnings.warn(
                    "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                    FutureWarning,
                )
            return True, ""

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """Check if the name is a valid axis name.

        Args:
            name (str): the axis name to check

        Returns:
            bool: whether the axis name is valid
        """
        is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid


def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
    """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.

    Args:
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Returns:
        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions
    """
    # adapted from einops.einops._prepare_transformation_recipe
    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
    try:
        left_str, right_str = pattern.split("->")
    except ValueError:
        raise ValueError("Pattern must contain a single '->' separator") from None

    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    left = ParsedExpression(left_str)
    right = ParsedExpression(right_str)

    if not left.has_ellipsis and right.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise ValueError(
            f"Ellipsis inside parenthesis on the left side is not allowed: {pattern}"
        )

    return left, right


def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Perform expression validations that are specific to the `rearrange` operation.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions
    """
    for length in axes_lengths.values():
        if (length_type := type(length)) is not int:
            raise TypeError(
                f"rearrange axis lengths must be integers, got: {length_type}"
            )

    if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
        raise ValueError("rearrange only supports unnamed axes of size 1")

    difference = set.symmetric_difference(left.identifiers, right.identifiers)
    if len(difference) > 0:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
        )

    unmatched_axes = axes_lengths.keys() - left.identifiers
    if len(unmatched_axes) > 0:
        raise ValueError(
            f"Identifiers not found in rearrange expression: {unmatched_axes}"
        )


def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Convert a collection of strings representing first class dims into a comma-separated string.

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(('d0',))
        'd0'

        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
        'd0, d1, d2, d3'

        >>> comma_separate([('d1', 'd4')])
        '(d1, d4)'

        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """
    return ", ".join(
        item
        if isinstance(item, str)
        else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
        for item in collection
    )
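Note (illustrative only, not part of the uploaded file): a quick sketch of how the parser above is used, assuming functorch is installed so the module imports cleanly:

from functorch.einops._parsing import comma_separate, parse_pattern

left, right = parse_pattern("b (h w) c -> b h w c", {"h": 2})
print(left.composition)    # [['b'], ['h', 'w'], ['c']]
print(right.composition)   # [['b'], ['h'], ['w'], ['c']]
print(comma_separate(["d0", ("d1", "d2")]))  # 'd0, (d1, d2)'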
gtm/lib/python3.12/site-packages/functorch/einops/rearrange.py
ADDED
@@ -0,0 +1,207 @@
from __future__ import annotations

import functools
from typing import Callable, Dict, List, Sequence, Tuple, Union

import torch

from functorch._C import dim as _C
from ._parsing import (
    _ellipsis,
    AnonymousAxis,
    comma_separate,
    parse_pattern,
    validate_rearrange_expressions,
)

__all__ = ["rearrange"]

dims = _C.dims


@functools.lru_cache(256)
def _create_rearrange_callable(
    tensor_ndim: int, pattern: str, **axes_lengths: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    r"""Translate an `einops`-style pattern into a callable that performs the rearrange using first-class dimensions.

    Since an equivalent result is computed for tensors with the same number of dimensions, with the same pattern and
    specified axes lengths, this function can be memoized.

    Args:
        tensor_ndim (int): the number of dimensions in the tensor to rearrange
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: a callable that performs the rearrangement
    """
    left, right = parse_pattern(pattern, axes_lengths)
    validate_rearrange_expressions(left, right, axes_lengths)

    n_anon_dims = sum(not dim for dim in left.composition)
    if left.has_ellipsis:
        n_ellipsis_dims = tensor_ndim - (len(left.composition) - 1)
        n_named_dims = len(left.identifiers) - 1

        if (pattern_ndim := n_anon_dims + n_named_dims) > tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be less than or equal to the number of "
                f"dimensions in the tensor ({tensor_ndim})"
            )
    else:
        n_ellipsis_dims = 0
        n_named_dims = len(left.identifiers)

        if (pattern_ndim := len(left.composition)) != tensor_ndim:
            raise ValueError(
                f"Number of dimensions in pattern ({pattern_ndim}) must be equal to the number of dimensions in "
                f"the tensor ({tensor_ndim})"
            )
    n_dims = n_named_dims + n_ellipsis_dims + n_anon_dims

    if n_dims == 0:
        # an identity rearrangement on a 0-dimension tensor
        return lambda tensor: tensor

    first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims))
    identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {}
    anon_axes: List[AnonymousAxis] = []

    # map the left-hand side identifiers to strings representing first class dims
    dims_i = 0
    for dimension in left.composition:
        if isinstance(dimension, list):
            for identifier in dimension:
                # non-unitary anon axes are not allowed in rearrange & unitary anon axes are represented as empty lists
                assert isinstance(identifier, str)
                identifier_dim_map[identifier] = (first_class_dims[dims_i],)
                dims_i += 1
            if not dimension:
                # unitary anonymous axis
                anon_axis = AnonymousAxis("1")
                identifier_dim_map[anon_axis] = (first_class_dims[dims_i],)
                anon_axes.append(anon_axis)
                dimension.append(anon_axis)
                dims_i += 1
        elif dimension == _ellipsis:
            identifier = _ellipsis
            identifier_dim_map[identifier] = tuple(
                first_class_dims[dims_i + j] for j in range(n_ellipsis_dims)
            )
            dims_i += n_ellipsis_dims
        else:
            raise ValueError(f"Unexpected dimension: {dimension}")

    def composition_to_dims(
        composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]]
    ) -> List[Union[str, Tuple[str, ...]]]:
        """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first
        class dims."""
        dim_composition: List[Union[str, Tuple[str, ...]]] = []
        for dimension in composition:
            if isinstance(dimension, list):
                dim_composition.append(
                    tuple(
                        dim
                        for identifier in dimension
                        for dim in identifier_dim_map[identifier]
                    )
                )
            elif dimension == _ellipsis:
                dim_composition.extend(identifier_dim_map[_ellipsis])
            else:
                raise ValueError(f"Unexpected dimension: {dimension}")
        return dim_composition

    left_dims = composition_to_dims(left.composition)
    right_dims = composition_to_dims(right.composition)
    anon_dims = tuple(identifier_dim_map[axis][0] for axis in anon_axes)
    specified_lengths = tuple(
        (identifier_dim_map[axis][0], length) for axis, length in axes_lengths.items()
    )

    custom_rearrange_callable_name = "do_rearrange"
    custom_rearrange_callable_code = (
        (
            f"def {custom_rearrange_callable_name}(tensor):\n"
            f"    {comma_separate(first_class_dims)} = dims({n_dims})\n"
        )
        + (
            "".join(
                f"    {dim}.size = {length}\n" for (dim, length) in specified_lengths
            )
            if specified_lengths
            else ""
        )
        + f"    tensor = tensor[{comma_separate(left_dims)}].order({comma_separate(right_dims)})\n"
        + (
            f"    return tensor.sum({comma_separate([anon_dims])}, keepdim=False)\n"
            if anon_dims
            else "    return tensor\n"
        )
    )

    exec(custom_rearrange_callable_code)
    return locals()[custom_rearrange_callable_name]


def rearrange(
    tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
    pattern: str,
    **axes_lengths: int,
) -> torch.Tensor:
    r"""A native implementation of `einops.rearrange`, a reader-friendly smart element reordering for multidimensional
    tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    See: https://einops.rocks/api/rearrange/

    Args:
        tensor (Tensor or sequence of Tensor): the tensor(s) to rearrange
        pattern (str): the rearrangement pattern
        axes_lengths (int): any additional length specifications for dimensions

    Returns:
        Tensor: the rearranged tensor

    Examples:
        >>> # suppose we have a set of 32 images in "h w c" format (height-width-channel)
        >>> images = torch.randn((32, 30, 40, 3))

        >>> # stack along first (batch) axis, output is a single array
        >>> rearrange(images, 'b h w c -> b h w c').shape
        torch.Size([32, 30, 40, 3])

        >>> # concatenate images along height (vertical axis), 960 = 32 * 30
        >>> rearrange(images, 'b h w c -> (b h) w c').shape
        torch.Size([960, 40, 3])

        >>> # concatenate images along horizontal axis, 1280 = 32 * 40
        >>> rearrange(images, 'b h w c -> h (b w) c').shape
        torch.Size([30, 1280, 3])

        >>> # reorder axes to "b c h w" format for deep learning
        >>> rearrange(images, 'b h w c -> b c h w').shape
        torch.Size([32, 3, 30, 40])

        >>> # flatten each image into a vector, 3600 = 30 * 40 * 3
        >>> rearrange(images, 'b h w c -> b (c h w)').shape
        torch.Size([32, 3600])

        >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
        >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
        torch.Size([128, 15, 20, 3])

        >>> # space-to-depth operation
        >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
        torch.Size([32, 15, 20, 12])
    """
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.stack(tensor)

    rearrange_callable = _create_rearrange_callable(
        tensor.ndim, pattern, **axes_lengths
    )

    return rearrange_callable(tensor)
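Note (illustrative only, not part of the uploaded file): for a call such as rearrange(x, 'b (h w) -> b h w', h=2), the memoized factory above builds and exec's source roughly equivalent to the following sketch (here `dims` is functorch._C.dim.dims, already bound in the factory's namespace):

def do_rearrange(tensor):
    d0, d1, d2 = dims(3)          # one first-class dim per named axis: b, h, w
    d1.size = 2                   # the specified length h=2
    tensor = tensor[(d0,), (d1, d2)].order((d0,), (d1,), (d2,))
    return tensor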
gtm/lib/python3.12/site-packages/functorch/experimental/__init__.py
ADDED
@@ -0,0 +1,6 @@
# PyTorch forward-mode is not mature yet
from torch._functorch.apis import chunk_vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import hessian, jacfwd, jvp

from functorch import functionalize
gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (505 Bytes).

gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/control_flow.cpython-312.pyc
ADDED
Binary file (443 Bytes).

gtm/lib/python3.12/site-packages/functorch/experimental/__pycache__/ops.cpython-312.pyc
ADDED
Binary file (254 Bytes).
gtm/lib/python3.12/site-packages/functorch/experimental/control_flow.py
ADDED
@@ -0,0 +1,8 @@
from torch import cond  # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401

from torch._higher_order_ops.map import (  # noqa: F401
    _stack_pytree,
    _unstack_pytree,
    map,
)
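Note (illustrative only, not part of the uploaded file): this module simply re-exports PyTorch's structured control-flow operators. A minimal sketch of the cond re-export, assuming a recent PyTorch build where torch.cond is available:

import torch
from functorch.experimental import control_flow

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# the predicate may be a bool or a boolean scalar tensor; operands are passed as a tuple
out = control_flow.cond(x.sum() > 0, true_fn, false_fn, (x,))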