|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Training and evaluation functions used by main.py.
|
""" |
|
import math |
|
import os |
|
import sys |
|
from typing import Iterable |
|
|
|
import torch |
|
import util.misc as utils |
|
|
|
from datasets.data_prefetcher import data_dict_to_cuda |
|
|
|
|
|
def train_one_epoch_mot(model: torch.nn.Module, criterion: torch.nn.Module,
                        data_loader: Iterable, optimizer: torch.optim.Optimizer,
                        device: torch.device, epoch: int, max_norm: float = 0):
    """Run one training epoch for the MOT model and return averaged stats.

    Args:
        model: network called as ``model(data_dict)`` on a batched data dict.
        criterion: loss module; must expose a ``weight_dict`` mapping loss
            names to scalar weights (losses without a weight are ignored
            in both the backward pass and the logged totals).
        data_loader: iterable yielding one data dict per iteration.
        optimizer: optimizer stepped once per iteration.
        device: device each batch is moved to before the forward pass.
        epoch: current epoch index (used only in the log header).
        max_norm: if > 0, clip gradients to this total norm; if 0, gradients
            are not clipped and the norm is only computed for logging.

    Returns:
        dict mapping each logged metric name to its global average over the
        epoch (synchronized across distributed processes).

    Exits the process (``sys.exit(1)``) if the reduced total loss becomes
    NaN/inf, since continuing would corrupt the weights.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for data_dict in metric_logger.log_every(data_loader, print_freq, header):
        data_dict = data_dict_to_cuda(data_dict, device)
        outputs = model(data_dict)

        loss_dict = criterion(outputs, data_dict)
        weight_dict = criterion.weight_dict
        # Weighted sum over only the losses that have a configured weight.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)

        # Reduce the loss dict across processes for logging only; the
        # backward pass below uses the local (unreduced) losses.
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        # Abort hard on a non-finite loss: further steps would only
        # propagate NaNs into the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            # No clipping requested: compute the total grad norm for logging.
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

    # Gather meter statistics from all distributed workers before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|