Spaces:
Running
Running
from typing import Union, Callable | |
import torch | |
class SeedSchemeFactory:
    """Name-keyed registry of seed-scheme implementations.

    Schemes are added with the ``register`` decorator and created by name
    with ``get_instance``. The registry is a class attribute, so it is shared
    by every caller in the process.
    """

    # Maps scheme name -> registered callable (normally a class).
    registry: dict[str, Callable] = {}

    @classmethod
    def get_schemes_name(cls) -> list[str]:
        """Return the names of all registered schemes (dict insertion order)."""
        return list(cls.registry.keys())

    @classmethod
    def register(cls, name: str) -> Callable[[Callable], Callable]:
        """
        Register the hash scheme by name. Hash scheme must be callable.

        Intended to be used as a decorator::

            @SeedSchemeFactory.register("my_scheme")
            class MyScheme(SeedScheme): ...

        Registering a name that already exists replaces the earlier entry
        and prints a notice.

        Args:
            name: name of seed scheme.

        Returns:
            A decorator that stores the decorated object under ``name`` and
            returns it unchanged.
        """

        def wrapper(wrapped_class):
            if name in cls.registry:
                print(f"Override {name} in SeedSchemeFactory")
            cls.registry[name] = wrapped_class
            return wrapped_class

        return wrapper

    @classmethod
    def get_instance(cls, name: str, *args, **kwargs):
        """
        Get the hash scheme by name.

        Args:
            name: name of seed scheme.
            *args: positional arguments forwarded to the registered callable.
            **kwargs: keyword arguments forwarded to the registered callable.

        Returns:
            The result of calling the registered callable with ``*args`` and
            ``**kwargs``, or ``None`` if ``name`` has not been registered.
        """
        if name not in cls.registry:
            return None
        return cls.registry[name](*args, **kwargs)
class SeedScheme:
    """Base callable scheme: maps ``input_ids`` to an ``int``.

    This default implementation ignores its argument and always yields 0;
    concrete schemes presumably override ``__call__`` (TODO confirm against
    the registered implementations).
    """

    def __call__(self, input_ids: torch.Tensor) -> int:
        default_value = 0
        return default_value
from seed_schemes import * | |