Zaesar's picture
Duplicate from haoheliu/audioldm-text-to-audio-generation
23f4617
raw
history blame
10.6 kB
import json
import logging
import math
import os
import time
from contextlib import suppress
import numpy as np
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
from open_clip import LPLoss, LPMetrics, lp_gather_features
from open_clip.utils import do_mixup, get_mix_lambda
from .distributed import is_master
from .zero_shot import zero_shot_eval
class AverageMeter:
    """Keep the latest value and a weighted running mean of everything seen.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def unwrap_model(model):
    """Return ``model.module`` if ``model`` has a ``module`` attribute, else ``model``.

    Wrappers such as ``DistributedDataParallel`` keep the real network in
    ``.module``; plain modules are returned unchanged.
    """
    return model.module if hasattr(model, "module") else model
def _as_list(obj):
    """Return ``list(obj.values())`` for a dict, otherwise ``[obj]``.

    Lets one code path handle either a single optimizer/scheduler or a dict
    of them.
    """
    if isinstance(obj, dict):
        return list(obj.values())
    return [obj]


def _backward_and_step(total_loss, optimizers, scaler, args):
    """Back-propagate ``total_loss`` and step every optimizer in ``optimizers``.

    Without a ``scaler`` this is a plain ``backward()`` followed by
    ``step()`` on each optimizer. With a ``scaler`` the loss is scaled before
    ``backward()``, each optimizer is stepped through the scaler, and
    ``scaler.update()`` runs once after all of them. When ``args.horovod`` is
    set, gradients are synchronized and unscaled explicitly and the step runs
    inside ``skip_synchronize()``.
    """
    if scaler is None:
        total_loss.backward()
        for opt in optimizers:
            opt.step()
        return

    scaler.scale(total_loss).backward()
    for opt in optimizers:
        if args.horovod:
            opt.synchronize()
            scaler.unscale_(opt)
            with opt.skip_synchronize():
                scaler.step(opt)
        else:
            scaler.step(opt)
    scaler.update()


def train_one_epoch(
    model,
    data,
    epoch,
    optimizer,
    scaler,
    scheduler,
    args,
    tb_writer=None,
    extra_suffix="",
):
    """Run one training epoch over ``data["train"]``.

    Args:
        model: network called as ``model(audio, mix_lambda=..., device=...)``.
            After ``unwrap_model`` it must expose
            ``clap_model.logit_scale_a`` and ``clap_model.logit_scale_t``,
            which are clamped to ``[0, ln(100)]`` after every step.
        data: mapping whose ``"train"`` entry has ``.dataloader`` (with
            ``num_batches`` / ``num_samples``) and ``.sampler``.
        epoch: epoch index; the global step is
            ``num_batches_per_epoch * epoch + i``.
        optimizer: a single optimizer or a dict of optimizers.
        scaler: AMP grad scaler, or ``None`` for unscaled training.
        scheduler: callable ``scheduler(step)`` or a dict of such callables.
        args: run configuration (device, precision, lp_loss, distributed,
            dataset_type, mixup, horovod, world_size, wandb, ...).
        tb_writer: optional TensorBoard writer.
        extra_suffix: appended to the ``train`` prefix of logged metric names.

    On the master process, every 100 batches and on the last batch, progress
    is logged and scalars are sent to TensorBoard / wandb. With a dict of
    optimizers, one learning-rate scalar ``lr_<key>`` is logged per
    optimizer, because ``add_scalar`` cannot accept a list.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss_fn = LPLoss(args.lp_loss)

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    # Field width used to right-align the sample counter in progress lines.
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    optimizers = _as_list(optimizer)
    schedulers = _as_list(scheduler)

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        for sched in schedulers:
            sched(step)

        # The whole batch dict goes to the model; only the label is moved here.
        audio = batch  # contains mel_spec, wavform, and longer list
        class_label = batch["class_label"].to(device=device, non_blocking=True)

        if args.mixup:
            # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
            mix_lambda = torch.from_numpy(
                get_mix_lambda(0.5, len(audio["waveform"]))
            ).to(device)
            class_label = do_mixup(class_label, mix_lambda)
        else:
            mix_lambda = None

        data_time_m.update(time.time() - end)

        for opt in optimizers:
            opt.zero_grad()

        with autocast():
            pred = model(audio, mix_lambda=mix_lambda, device=device)
            total_loss = loss_fn(pred, class_label)

        _backward_and_step(total_loss, optimizers, scaler, args)

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            clap_model = unwrap_model(model).clap_model
            clap_model.logit_scale_a.clamp_(0, math.log(100))
            clap_model.logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audio, dict):
                batch_size = len(audio["waveform"])
            else:
                batch_size = len(audio)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            progress = (
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f} "
            )
            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
            }
            if isinstance(optimizer, dict):
                lrs = [o_.param_groups[0]["lr"] for o_ in optimizer.values()]
                logging.info(progress + f"LR: {lrs}")
                # add_scalar needs a scalar, not a list: one entry per optimizer.
                for key, o_ in optimizer.items():
                    log_data[f"lr_{key}"] = o_.param_groups[0]["lr"]
            else:
                lr = optimizer.param_groups[0]["lr"]
                logging.info(progress + f"LR: {lr:5f} ")
                log_data["lr"] = lr

            for name, val in log_data.items():
                name = f"train{extra_suffix}/{name}"
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for
def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
    """Evaluate ``model`` on ``data["val"]``, then log and persist the metrics.

    Without ``args.parallel_eval``, only the master process evaluates and
    other ranks return ``{}`` immediately. With ``args.parallel_eval``, every
    rank iterates the loader and ``pred`` / ``class_label`` are combined with
    ``lp_gather_features``; metrics are computed on the master only.

    Validation runs only when ``"val"`` is in ``data``, ``args.val_frequency``
    is truthy, and ``epoch`` is a multiple of it or equals ``args.epochs``.

    Returns:
        dict of metric values plus ``"epoch"`` on the master when validation
        ran; otherwise an empty dict.
    """
    metrics = {}
    if not args.parallel_eval and not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    # CHANGE
    # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    # metrics.update(zero_shot_metrics)
    if is_master(args):
        print("Evaluating...")
    metric_names = args.lp_metrics.split(",")
    eval_tool = LPMetrics(metric_names=metric_names)

    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    if "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        if args.parallel_eval:
            sampler = data["val"].sampler
            if args.distributed and sampler is not None:
                sampler.set_epoch(epoch)
        # Initialise in both modes; previously this was only set when
        # parallel_eval was off, so parallel eval raised UnboundLocalError.
        num_samples = 0
        samples_per_val = dataloader.num_samples

        eval_info = {"pred": [], "target": []}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audio = batch  # contains mel_spec, wavform, and longer list
                class_label = batch["class_label"].to(device=device, non_blocking=True)
                with autocast():
                    pred = model(audio, device=device)
                if args.parallel_eval:
                    # presumably a cross-rank gather that every rank must join — TODO confirm
                    pred, class_label = lp_gather_features(
                        pred, class_label, args.world_size, args.horovod
                    )
                eval_info["pred"].append(pred)
                eval_info["target"].append(class_label)
                num_samples += class_label.shape[0]
                if (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

        if is_master(args):
            eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
            eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
            metric_dict = eval_tool.evaluate_mertics(
                eval_info["pred"], eval_info["target"]
            )
            metrics.update(metric_dict)
            if "epoch" not in metrics:
                metrics.update({"epoch": epoch})

    # Only the master logs / persists; nothing to report if no metrics exist.
    if not is_master(args) or not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\n".join(f"{m}: {round(metrics[m], 4):.4f}" for m in metrics)
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})

    return metrics