# VoiceRestore / tensor_typing.py
# Author: jadechoghari — commit 96e64e9 ("add initial files"), 447 bytes.
from torch import Tensor
from jaxtyping import (
Float,
Int,
Bool
)
# Despite its name, jaxtyping also works for PyTorch tensors.
class TorchTyping:
    """Adapter that pre-binds ``torch.Tensor`` into a jaxtyping dtype annotation.

    Subscripting an instance with a shape string forwards to the wrapped
    dtype with ``Tensor`` as the array type. For example, ``wrapper["b n"]``
    evaluates ``abstract_dtype[Tensor, "b n"]``.
    """

    def __init__(self, abstract_dtype):
        # The wrapped jaxtyping dtype (this module wraps Float, Int and Bool).
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        """Return ``abstract_dtype[Tensor, shapes]``."""
        wrapped = self.abstract_dtype
        # ``x[a, b]`` is sugar for ``x[(a, b)]``; build the key explicitly.
        key = (Tensor, shapes)
        return wrapped[key]
# Rebind the jaxtyping dtypes to Tensor-bound wrappers so annotations can be
# written as ``Float["b n d"]`` instead of ``Float[Tensor, "b n d"]``.
Float = TorchTyping(Float)
Int = TorchTyping(Int)
Bool = TorchTyping(Bool)

# ``__all__`` must contain attribute *names* (strings). Listing the objects
# themselves makes ``from <this module> import *`` raise TypeError.
__all__ = [
    "Float",
    "Int",
    "Bool",
]