Spaces:
Sleeping
Sleeping
File size: 648 Bytes
ca25718 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
def get_optimizer(
    optimizer_name: str, latents: torch.Tensor, lr: float, nesterov: bool
):
    """Build a torch optimizer over ``latents``, selected by name.

    Args:
        optimizer_name: Exactly one of ``"adam"``, ``"sgd"`` or ``"lbfgs"``
            (the comparison is case-sensitive).
        latents: Tensor handed to the optimizer as its only parameter.
        lr: Learning rate forwarded to the chosen optimizer.
        nesterov: Forwarded to ``torch.optim.SGD`` only; the other
            optimizers do not use it.

    Returns:
        The newly constructed optimizer instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the names above.
    """
    params = [latents]

    if optimizer_name == "adam":
        # eps is pinned to 1e-2 here instead of relying on the library default.
        return torch.optim.Adam(params, lr=lr, eps=1e-2)

    if optimizer_name == "sgd":
        # Momentum is fixed at 0.9; only the Nesterov switch is caller-controlled.
        return torch.optim.SGD(params, lr=lr, momentum=0.9, nesterov=nesterov)

    if optimizer_name == "lbfgs":
        lbfgs_options = {
            "max_iter": 10,
            "history_size": 3,
            "line_search_fn": "strong_wolfe",
        }
        return torch.optim.LBFGS(params, lr=lr, **lbfgs_options)

    raise ValueError(f"Unknown optimizer {optimizer_name}")
|