maskgct / modules /diffusion /karras /random_utils.py
Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
6.06 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch as th
def get_generator(generator, num_samples=0, seed=0):
    """Build one of this module's random-number generators by name.

    Args:
        generator: ``"dummy"`` (torch's global RNG, non-deterministic),
            ``"determ"`` (single shared generator, subsampled per batch) or
            ``"determ-indiv"`` (one generator per sample).
        num_samples: total number of samples; forwarded to the deterministic
            generators and ignored for ``"dummy"``.
        seed: base seed; forwarded to the deterministic generators and
            ignored for ``"dummy"``.

    Returns:
        A generator object exposing ``randn``, ``randint`` and ``randn_like``.

    Raises:
        NotImplementedError: if ``generator`` is not one of the names above.
    """
    if generator == "dummy":
        return DummyGenerator()
    if generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    if generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    # Name the rejected value so configuration typos are easy to diagnose.
    raise NotImplementedError(
        f"Unknown generator {generator!r}; "
        "expected 'dummy', 'determ' or 'determ-indiv'"
    )
class DummyGenerator:
    """Non-deterministic generator that forwards directly to torch's global RNG.

    It exposes the same ``randn`` / ``randint`` / ``randn_like`` methods as the
    deterministic generators in this module, so the two can be swapped freely.
    """

    def randn(self, *positional, **keyword):
        """Return ``th.randn(*positional, **keyword)``."""
        sample = th.randn(*positional, **keyword)
        return sample

    def randint(self, *positional, **keyword):
        """Return ``th.randint(*positional, **keyword)``."""
        sample = th.randint(*positional, **keyword)
        return sample

    def randn_like(self, *positional, **keyword):
        """Return ``th.randn_like(*positional, **keyword)``."""
        sample = th.randn_like(*positional, **keyword)
        return sample
class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a single rng and samples num_samples sized randomness and subsamples the current indices

    Each call draws a tensor whose leading dimension is ``num_samples`` from one
    shared generator and returns only the rows for the current batch, i.e. rows
    ``done_samples .. done_samples + batch_size - 1``, clamped to the last
    valid row. ``set_done_samples`` re-seeds the generator. So if the same
    sequence of calls (same shapes and dtypes) is made for every batch, a given
    sample index receives the same values however the samples are batched.

    Only a single rank is supported (``rank = 0``, ``world_size = 1``).
    """

    def __init__(self, num_samples, seed=0):
        """Create CPU (and, if available, CUDA) generators seeded with ``seed``."""
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        if th.cuda.is_available():
            # Single rank: bind the CUDA generator to the current CUDA device.
            cuda_device = th.device("cuda", th.cuda.current_device())
            self.rng_cuda = th.Generator(device=cuda_device)
        self.set_seed(seed)

    def get_global_size_and_indices(self, size):
        """Map a requested batch ``size`` to the full draw shape and batch rows.

        Returns:
            ``(global_size, indices)``. ``global_size`` is ``size`` with its
            leading dimension replaced by ``num_samples``. ``indices`` selects
            this batch's rows from that draw.
        """
        global_size = (self.num_samples, *size[1:])
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Rows past the end (padding in the last batch) repeat the last sample.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """Return the CPU generator for CPU devices, otherwise the CUDA one."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Normal samples of shape ``size`` for the current batch."""
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[
            indices
        ]

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Integers in ``[low, high)`` of shape ``size`` for the current batch."""
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randint(
            low, high, generator=generator, size=global_size, dtype=dtype, device=device
        )[indices]

    def randn_like(self, tensor):
        """Normal samples matching ``tensor``'s size, dtype and device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Advance to the batch starting at ``done_samples`` and re-seed."""
        self.done_samples = done_samples
        # Re-seeding makes the next draws repeat the previous batch's full
        # tensors, so only the selected rows differ between batches.
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the seed given at construction."""
        return self.seed

    def set_seed(self, seed):
        """Re-seed the CPU (and, if available, CUDA) generator with ``seed``."""
        # NOTE(review): this does not update ``self.seed``, so ``get_seed`` and
        # ``set_done_samples`` keep using the constructor seed — confirm intent.
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)
class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a separate rng for each sample to reduce memory usage

    Sample ``i`` owns generator ``i``, seeded with ``i + num_samples * seed``.
    Each call draws one leading-dimension-1 slice per batch row from that row's
    generator and concatenates them. This avoids materialising a
    ``num_samples``-sized tensor on every call.

    Only a single rank is supported (``rank = 0``, ``world_size = 1``).
    """

    def __init__(self, num_samples, seed=0):
        """Create one CPU (and, if available, CUDA) generator per sample."""
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        if th.cuda.is_available():
            # Single rank: bind every CUDA generator to the current CUDA device.
            cuda_device = th.device("cuda", th.cuda.current_device())
            self.rng_cuda = [
                th.Generator(device=cuda_device) for _ in range(num_samples)
            ]
        self.set_seed(seed)

    def get_size_and_indices(self, size):
        """Map a requested batch ``size`` to the per-row shape and sample indices.

        Returns:
            ``(row_size, indices)``. ``row_size`` is ``size`` with its leading
            dimension set to 1. ``indices`` are the sample ids (and generator
            ids) of this batch's rows.
        """
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # NOTE(review): padding rows past the end are clamped to the last
        # sample, so its generator is advanced once per padded row in the same
        # call — confirm this is acceptable for the final partial batch.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """Return the list of CPU generators for CPU devices, else the CUDA list."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Normal samples of shape ``size``; row ``k`` uses its sample's generator."""
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randn(*size, generator=generator[i], dtype=dtype, device=device)
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Integers in ``[low, high)`` of shape ``size``, one generator per row."""
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randint(
                    low,
                    high,
                    generator=generator[i],
                    size=size,
                    dtype=dtype,
                    device=device,
                )
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Normal samples matching ``tensor``'s size, dtype and device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Advance to the batch starting at ``done_samples``.

        No re-seed is needed: each sample already has its own generator.
        """
        self.done_samples = done_samples

    def get_seed(self):
        """Return the seed given at construction."""
        return self.seed

    def set_seed(self, seed):
        """Seed generator ``i`` with ``i + num_samples * seed``.

        For distinct integer seeds, this gives the per-sample generators
        disjoint seed ranges.
        """
        for i, rng in enumerate(self.rng_cpu):
            rng.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng in enumerate(self.rng_cuda):
                rng.manual_seed(i + self.num_samples * seed)