Spaces:
Runtime error
Runtime error
from functools import partial | |
import torch | |
from opensora.registry import SCHEDULERS | |
from .dpm_solver import DPMS | |
class DMP_SOLVER:
    """Sampler that feeds random noise and text conditioning to a ``DPMS`` solver.

    Attributes:
        num_sampling_steps: Value forwarded as ``steps`` to the solver's
            ``sample`` call.
        cfg_scale: Value forwarded as ``cfg_scale`` when building ``DPMS``.
    """

    def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
        """Store the sampling configuration; no other work is done here."""
        self.cfg_scale = cfg_scale
        self.num_sampling_steps = num_sampling_steps

    def sample(
        self,
        model,
        text_encoder,
        z_size,
        prompts,
        device,
        additional_args=None,
    ):
        """Generate one sample per prompt.

        Args:
            model: Object whose ``forward`` is invoked through
                ``forward_with_dpmsolver``.
            text_encoder: Object providing ``encode(prompts)`` (must return a
                mapping containing the key ``"y"``) and ``null(n)``.
            z_size: Sizes of the noise tensor after the batch dimension.
            prompts: Sized collection of prompts; its length is the batch size.
            device: Device on which the initial noise is allocated.
            additional_args: Optional mapping merged into the keyword
                arguments handed to the model.

        Returns:
            Whatever the solver's ``sample`` call returns.
        """
        batch_size = len(prompts)
        # Noise is drawn before encoding, keeping the original call order.
        noise = torch.randn(batch_size, *z_size, device=device)

        model_kwargs = text_encoder.encode(prompts)
        # "y" is passed separately as the condition, so remove it from kwargs.
        cond = model_kwargs.pop("y")
        uncond = text_encoder.null(batch_size)

        if additional_args is not None:
            model_kwargs.update(additional_args)

        solver = DPMS(
            partial(forward_with_dpmsolver, model),
            condition=cond,
            uncondition=uncond,
            cfg_scale=self.cfg_scale,
            model_kwargs=model_kwargs,
        )
        return solver.sample(
            noise,
            steps=self.num_sampling_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
    """Run the model's forward pass and keep only the first half along dim 1.

    DPM-Solver does not need the variance prediction, so the output is split
    into two chunks along dimension 1 and only the first chunk is returned.

    Args:
        self: The model; its ``forward`` method is called. This is a
            module-level function, so the model is passed in explicitly.
        x: First positional input forwarded to ``forward``.
        timestep: Second positional input forwarded to ``forward``.
        y: Third positional input forwarded to ``forward``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The first of the chunks produced by ``chunk(2, dim=1)``.
    """
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    pieces = self.forward(x, timestep, y, **kwargs).chunk(2, dim=1)
    return pieces[0]