Spaces:
Sleeping
Sleeping
File size: 985 Bytes
71f183c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from torchvision import models
from modelguidedattacks.guides.instance_guide import InstanceGuide
from modelguidedattacks.guides.unguided import Unguided
from modelguidedattacks import losses
from .cls_models.registry import get_model
# Lookup table from the ``config.guide_model`` string to the guide callable
# that ``setup_model`` invokes as ``guide(model, config, **kwargs)``.
guide_model_registry = dict(
    instance_guided=InstanceGuide,
    unguided=Unguided,
)
# Lookup table from the ``config.loss`` string to the matching entry in the
# ``losses`` module; ``setup_model`` indexes it with ``config.loss``.
loss_registry = dict(
    cvxproj=losses.CVXProjLoss,
    cwk=losses.CWExtensionLoss,
    ad=losses.AdversarialDistillationLoss,
)
def setup_model(config, device):
    """Build the guide model selected by ``config.guide_model``.

    The base model is obtained from ``get_model(config.dataset,
    config.model, device)`` and then wrapped by the registry entry for
    ``config.guide_model``, which is called as
    ``guide(model, config, **kwargs)``.

    Args:
        config: Configuration object. Always read: ``dataset``, ``model``
            and ``guide_model``. Read only for the ``"unguided"`` guide:
            ``unguided_iterations``, ``unguided_lr``, ``loss``,
            ``binary_search_steps`` and ``topk_loss_coef_upper``.
        device: Forwarded unchanged to ``get_model``.

    Returns:
        Whatever the selected ``guide_model_registry`` entry returns.

    Raises:
        KeyError: If ``config.guide_model`` is not a key of
            ``guide_model_registry``, or (for the unguided guide) if
            ``config.loss`` is not a key of ``loss_registry``.
    """
    model = get_model(config.dataset, config.model, device)

    kwargs = {}
    # NOTE(review): indentation was lost in the copy this was restored from;
    # all five settings below are assumed to belong to the "unguided" branch
    # -- confirm against the upstream file.
    if config.guide_model == "unguided":
        kwargs["iterations"] = config.unguided_iterations
        kwargs["lr"] = config.unguided_lr
        # The registry entry itself is forwarded; it is not called here.
        kwargs["loss_fn"] = loss_registry[config.loss]
        kwargs["binary_search_steps"] = config.binary_search_steps
        kwargs["topk_loss_coef_upper"] = config.topk_loss_coef_upper

    return guide_model_registry[config.guide_model](model, config, **kwargs)