Spaces:
Sleeping
Sleeping
File size: 1,076 Bytes
91d5a5e 7235a64 91d5a5e 7235a64 91d5a5e 7235a64 91d5a5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from typing import Union, Callable
import torch
class SeedSchemeFactory:
    """Name-keyed registry of seed scheme classes.

    Classes are added with the ``register`` decorator and constructed on
    demand with ``get_instance``. All entries live in one class-level dict,
    ``registry``, shared by every caller.
    """

    # Maps scheme name -> registered class (or other callable).
    registry = {}

    @classmethod
    def get_schemes_name(cls) -> list[str]:
        """Return the registered scheme names, in first-registration order."""
        return [scheme_name for scheme_name in cls.registry]

    @classmethod
    def register(cls, name: str):
        """Build a decorator that stores the decorated class under ``name``.

        Registering a name that is already present replaces the old entry and
        prints a notice. The decorated object is handed back unchanged, so the
        class remains usable directly after decoration.

        Args:
            name: name of seed scheme.
        """

        def decorator(scheme_cls):
            already_present = name in cls.registry
            if already_present:
                print(f"Override {name} in SeedSchemeFactory")
            cls.registry[name] = scheme_cls
            return scheme_cls

        return decorator

    @classmethod
    def get_instance(cls, name: str, *args, **kwargs):
        """Construct the seed scheme registered under ``name``.

        Any extra positional and keyword arguments are forwarded to the
        registered class.

        Args:
            name: name of seed scheme.

        Returns:
            The new instance, or None when ``name`` was never registered.
        """
        if name not in cls.registry:
            return None
        scheme_cls = cls.registry[name]
        return scheme_cls(*args, **kwargs)
class SeedScheme:
    """Base seed scheme: a callable that maps an input tensor to an int seed.

    This base implementation ignores its input and always yields 0.
    NOTE(review): concrete schemes presumably subclass this and override
    ``__call__`` — confirm against the scheme implementations.
    """

    def __call__(self, input_ids: torch.Tensor) -> int:
        """Return the seed derived from ``input_ids``.

        Args:
            input_ids: tensor of ids (presumably token ids — confirm with
                callers); unused by this base class.

        Returns:
            Always 0 for the base class.
        """
        seed = 0
        return seed
from seed_schemes import *
|