File size: 329 Bytes
60094bd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
from typing import Callable, Dict, List

from torch import Tensor
from torch.nn import Module

# Private building blocks shared by both callable signatures below, so the
# common argument list is spelled out once.
_ModuleDict = Dict[str, Module]
_TensorList = List[Tensor]
_TensorListDict = Dict[str, List[Tensor]]

# Signature of a cached computation that produces a single Tensor.
# Positional arguments, in order:
#   1. Dict[str, Module]        -- presumably modules looked up by name
#                                  (NOTE(review): confirm against callers)
#   2. List[Tensor]             -- a list of tensors, presumably the inputs
#   3. Dict[str, List[Tensor]]  -- str-keyed tensor lists; given the alias
#                                  name this is likely the cache (TODO confirm)
TensorCachedComputationFunc = Callable[
    [_ModuleDict, _TensorList, _TensorListDict],
    Tensor,
]

# Same argument signature as TensorCachedComputationFunc, but the
# computation returns a List[Tensor] instead of a single Tensor.
TensorListCachedComputationFunc = Callable[
    [_ModuleDict, _TensorList, _TensorListDict],
    List[Tensor],
]