File size: 447 Bytes
96e64e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from torch import Tensor
from jaxtyping import (
Float,
Int,
Bool
)
# Despite its name, jaxtyping is not JAX-specific; it also works with PyTorch tensors.
class TorchTyping:
    """Subscriptable wrapper that fixes the array type to ``torch.Tensor``.

    An instance stores a subscriptable dtype object (for example jaxtyping's
    ``Float``). Indexing the instance with a shape string forwards to that
    object, indexed with the pair ``(Tensor, shapes)``.
    """

    def __init__(self, abstract_dtype):
        """Store the dtype object that subscripting is delegated to."""
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]`` for the given shape string."""
        # ``obj[a, b]`` is sugar for ``obj[(a, b)]``; build the key explicitly.
        key = (Tensor, shapes)
        return self.abstract_dtype[key]
# Rebind the jaxtyping names imported above to Tensor-bound wrappers, so that
# ``Float["batch dim"]`` evaluates to ``jaxtyping.Float[Tensor, "batch dim"]``.
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)

# ``__all__`` must list the public *names* as strings. Listing the objects
# themselves makes ``from <module> import *`` raise TypeError.
__all__ = [
    "Float",
    "Int",
    "Bool",
]
|