Spaces:
Sleeping
Sleeping
import torch | |
def get_optimizer(
    optimizer_name: str, latents: torch.Tensor, lr: float, nesterov: bool
) -> torch.optim.Optimizer:
    """Build a torch optimizer whose only parameter is ``latents``.

    Args:
        optimizer_name: One of ``"adam"``, ``"sgd"`` or ``"lbfgs"``
            (matched exactly, case-sensitive).
        latents: The tensor to optimize; it is wrapped in a one-element
            parameter list. PyTorch optimizers reject non-leaf tensors.
        lr: Learning rate passed to the optimizer.
        nesterov: Whether to use Nesterov momentum. Only used by ``"sgd"``;
            ignored for the other optimizers.

    Returns:
        The configured optimizer:

        - ``"adam"``: Adam with ``eps=1e-2`` (much larger than PyTorch's
          default of 1e-8).
        - ``"sgd"``: SGD with ``momentum=0.9``.
        - ``"lbfgs"``: L-BFGS with ``max_iter=10``, ``history_size=3`` and a
          strong-Wolfe line search. Its ``step()`` must be called with a
          closure that re-evaluates the loss.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    params = [latents]

    if optimizer_name == "adam":
        # Deliberately large eps (kept from the original configuration).
        return torch.optim.Adam(params, lr=lr, eps=1e-2)

    if optimizer_name == "sgd":
        # Momentum is fixed > 0, which PyTorch requires whenever nesterov=True.
        return torch.optim.SGD(params, lr=lr, nesterov=nesterov, momentum=0.9)

    if optimizer_name == "lbfgs":
        return torch.optim.LBFGS(
            params,
            lr=lr,
            max_iter=10,
            history_size=3,
            line_search_fn="strong_wolfe",
        )

    raise ValueError(f"Unknown optimizer {optimizer_name}")