Spaces:
Sleeping
Sleeping
from typing import Union | |
import torch | |
import io | |
from cryptography.hazmat.primitives import hashes | |
from seed_scheme_factory import SeedSchemeFactory, SeedScheme | |
class SHALeftHash(SeedScheme):
    """Seed scheme that derives an RNG seed from a SHA-224 hash of token ids.

    The hashed byte stream is the optional 8-byte private key followed by each
    token id encoded as an 8-byte big-endian unsigned integer. The seed is the
    first 8 bytes of the digest, read as a big-endian unsigned integer.
    """

    def __init__(self, private_key: Union[int, None] = None, *args, **kwargs):
        """Store the optional private key as 8 big-endian bytes.

        Args:
            private_key: Secret integer mixed into every hash, or ``None`` to
                hash the token ids alone. Must fit in an unsigned 64-bit
                integer; otherwise ``int.to_bytes`` raises ``OverflowError``.
            *args, **kwargs: Accepted and ignored. NOTE(review): presumably
                so a factory can pass a shared argument set to every scheme —
                confirm against SeedSchemeFactory.
        """
        self.private_key = (
            private_key.to_bytes(8, "big") if private_key is not None else None
        )

    def __call__(self, input_ids: torch.Tensor) -> int:
        """Return a deterministic seed derived from ``input_ids``.

        Args:
            input_ids: Tensor of token ids. Each value is converted with
                ``int()`` and must fit in an unsigned 64-bit integer (negative
                ids raise ``OverflowError``). Input with more than one
                dimension is flattened in row-major order.

        Returns:
            An int in ``[0, 2**64)``.
        """
        # A single device->host transfer for the whole tensor, instead of a
        # synchronizing ``.item()`` call for every element.
        token_ids = input_ids.reshape(-1).tolist()

        digest = hashes.Hash(hashes.SHA224())
        if self.private_key is not None:
            digest.update(self.private_key)
        # Incremental updates hash exactly the same byte stream as
        # concatenating key + ids first.
        digest.update(b"".join(int(tok).to_bytes(8, "big") for tok in token_ids))
        hashed_value = digest.finalize()

        # Keep only the first 8 bytes so the seed fits torch's 64-bit RNG seed.
        # NOTE: the value is unsigned, so it can exceed the int64 maximum;
        # torch.manual_seed documents accepting values up to
        # 0xffff_ffff_ffff_ffff.
        return int.from_bytes(hashed_value[:8], byteorder="big")