Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from typing import Dict, List | |
from torch import Tensor | |
from torch.nn import Module | |
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc | |
class CachedComputationProtocol(ABC):
    """Memoizing protocol for computations identified by a string key.

    Results live in a caller-supplied ``outputs`` dict. ``get_output`` only
    invokes ``compute_output`` for keys not already present in that dict, so
    repeated requests for the same key through ``get_output`` reuse the cached
    value. Subclasses provide the actual computation by implementing
    ``compute_output``.
    """

    def get_output(self,
                   key: str,
                   modules: Dict[str, Module],
                   batch: List[Tensor],
                   outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Return the output list for ``key``, computing and caching it on a miss.

        Args:
            key: Name of the requested output.
            modules: Named modules, passed through unchanged to ``compute_output``.
            batch: Input tensors, passed through unchanged to ``compute_output``.
            outputs: Cache of already-computed results. Mutated in place: on a
                miss, the value returned by ``compute_output`` is stored under
                ``key``.

        Returns:
            ``outputs[key]`` after the lookup or computation.
        """
        if key not in outputs:
            # Whatever compute_output returns becomes the cached entry; an
            # existing entry is never recomputed.
            outputs[key] = self.compute_output(key, modules, batch, outputs)
        return outputs[key]

    @abstractmethod
    def compute_output(self,
                       key: str,
                       modules: Dict[str, Module],
                       batch: List[Tensor],
                       outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
        """Compute the output list for ``key``. Subclasses must implement this.

        ``get_output`` calls this only when ``key`` is missing from ``outputs``
        and stores the return value there. Because the same ``outputs`` dict is
        passed in, an implementation can presumably fetch intermediate results
        via ``self.get_output(...)``. Confirm against concrete subclasses.

        Raises:
            NotImplementedError: if reached, e.g. through a ``super()`` call from
                a subclass that does not provide its own computation.
        """
        # Raise rather than return None: get_output would otherwise cache None
        # under ``key`` and the error would only surface at a later indexing.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement compute_output (key={key!r})")

    def get_output_tensor_func(self, key: str, index: int) -> "TensorCachedComputationFunc":
        """Return a callable that yields element ``index`` of the output list for ``key``.

        The returned callable takes ``(modules, batch, outputs)`` and goes
        through ``get_output``, so it benefits from the same caching.
        """
        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]) -> Tensor:
            return self.get_output(key, modules, batch, outputs)[index]

        return func

    def get_output_tensor_list_func(self, key: str) -> "TensorListCachedComputationFunc":
        """Return a callable that yields the full output list for ``key``.

        The returned callable takes ``(modules, batch, outputs)`` and goes
        through ``get_output``, so it benefits from the same caching.
        """
        def func(modules: Dict[str, Module],
                 batch: List[Tensor],
                 outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
            return self.get_output(key, modules, batch, outputs)

        return func