from typing import Any, Union
import ignite.distributed as idist
import torch
from ignite.engine import DeterministicEngine, Engine, Events
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler, Sampler


def setup_trainer(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    device: Union[str, torch.device],
    train_sampler: Sampler,
) -> Union[Engine, DeterministicEngine]:
    """Create the training engine that runs one optimization step per batch."""

    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
        if config.overfit:
            # Keep batch norm in eval mode (no running-stat updates) when overfitting
            model.eval()
        else:
            model.train()

        samples = batch[0].to(device, non_blocking=True)
        targets = batch[1].to(device, non_blocking=True)
        attack_targets = batch[2].to(device, non_blocking=True)
        sample_ids = batch[3].to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs = model(samples, attack_targets)
            loss = loss_fn(outputs, attack_targets, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss = loss.item()
        engine.state.metrics = {
            "epoch": engine.state.epoch,
            "train_loss": train_loss,
        }
        return {"train_loss": train_loss}

    trainer = Engine(train_function)

    # Set the epoch on the distributed sampler so shuffling differs across epochs
    @trainer.on(Events.EPOCH_STARTED)
    def set_epoch():
        if idist.get_world_size() > 1 and isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(trainer.state.epoch - 1)

    return trainer
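

# ---------------------------------------------------------------------------
# Note (an assumption, not part of the original file): `setup_trainer` runs the
# forward pass under autocast but backpropagates the unscaled loss. When
# `use_amp` is enabled, autocast is usually paired with
# `torch.cuda.amp.GradScaler`; the sketch below shows what that variant of the
# inner optimization step could look like. `amp_train_step_sketch` and its
# argument list are illustrative only.
# ---------------------------------------------------------------------------
def amp_train_step_sketch(
    config: Any,
    model: Module,
    optimizer: Optimizer,
    loss_fn: Module,
    scaler: torch.cuda.amp.GradScaler,
    samples: torch.Tensor,
    targets: torch.Tensor,
    attack_targets: torch.Tensor,
) -> float:
    model.train()
    with autocast(config.use_amp):
        outputs = model(samples, attack_targets)
        loss = loss_fn(outputs, attack_targets, targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration
    optimizer.zero_grad()
    return loss.item()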


def setup_evaluator(
    config: Any,
    model: Module,
    device: Union[str, torch.device],
) -> Engine:
    @torch.no_grad()
    def eval_function(engine: Engine, batch: Any):
        model.eval()

        samples, gt_labels, attack_targets, sample_ids = batch
        samples = samples.to(device, non_blocking=True)
        gt_labels = gt_labels.to(device, non_blocking=True)
        attack_targets = attack_targets.to(device, non_blocking=True)
        sample_ids = sample_ids.to(device, non_blocking=True)

        with autocast(config.use_amp):
            outputs, perturbations = model(samples, attack_targets, gt_labels)

        return outputs, attack_targets, {
            "gt_targets": gt_labels,
            "perturbations": perturbations,
        }

    return Engine(eval_function)
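

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original file): it
# wires the two factories above to a toy model, a toy loss, and random data so
# the engines can be exercised end to end on CPU. `DummyAttackModel`,
# `DummyAttackLoss`, and the SimpleNamespace config are assumptions; the real
# project builds its config, model, loss, and dataloaders elsewhere.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch.nn.functional as F
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    config = SimpleNamespace(overfit=False, use_amp=False, max_epochs=2)
    device = idist.device()

    class DummyAttackModel(Module):
        """Toy stand-in: logits for training, (logits, perturbations) for eval."""

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(8, 10)

        def forward(self, samples, attack_targets, gt_labels=None):
            outputs = self.net(samples)
            if gt_labels is None:
                return outputs
            return outputs, torch.zeros_like(samples)

    class DummyAttackLoss(Module):
        """Toy stand-in: cross-entropy towards the attack targets."""

        def forward(self, outputs, attack_targets, targets):
            return F.cross_entropy(outputs, attack_targets)

    # Batches follow the (samples, targets, attack_targets, sample_ids) layout
    # expected by train_function and eval_function above.
    dataset = TensorDataset(
        torch.randn(32, 8),
        torch.randint(0, 10, (32,)),
        torch.randint(0, 10, (32,)),
        torch.arange(32),
    )
    train_sampler = RandomSampler(dataset)
    train_loader = DataLoader(dataset, batch_size=8, sampler=train_sampler)
    eval_loader = DataLoader(dataset, batch_size=8)

    model = DummyAttackModel().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    trainer = setup_trainer(config, model, optimizer, DummyAttackLoss(), device, train_sampler)
    evaluator = setup_evaluator(config, model, device)

    # Run evaluation at the end of every training epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation():
        evaluator.run(eval_loader)

    trainer.run(train_loader, max_epochs=config.max_epochs)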