from torch import Tensor | |
from jaxtyping import ( | |
Float, | |
Int, | |
Bool | |
) | |
# jaxtyping is a misnomer, works for pytorch | |
class TorchTyping: | |
def __init__(self, abstract_dtype): | |
self.abstract_dtype = abstract_dtype | |
def __getitem__(self, shapes: str): | |
return self.abstract_dtype[Tensor, shapes] | |
Float = TorchTyping(Float) | |
Int = TorchTyping(Int) | |
Bool = TorchTyping(Bool) | |
__all__ = [ | |
Float, | |
Int, | |
Bool | |
] | |