aps's picture
Commit efficientat
4848335
raw
history blame
3.07 kB
def NAME_TO_WIDTH(name):
    """Return the width value associated with the first four characters of ``name``.

    Args:
        name: A string whose four-character prefix (e.g. ``'mn04'``) selects
            the width. Anything after the fourth character is ignored.

    Returns:
        The float mapped to the prefix, or ``1.0`` when the prefix is not in
        the table or ``name`` cannot be sliced/hashed (e.g. ``None``).
    """
    width_by_prefix = {
        'mn04': 0.4,
        'mn05': 0.5,
        'mn10': 1.0,
        'mn20': 2.0,
        'mn30': 3.0,
        'mn40': 4.0,
    }
    try:
        return width_by_prefix[name[:4]]
    except (KeyError, TypeError):
        # KeyError: unknown prefix. TypeError: ``name`` is not sliceable or the
        # slice is unhashable. Both keep the historical fallback of 1.0, while
        # unrelated exceptions (KeyboardInterrupt, SystemExit, ...) propagate.
        return 1.0
import csv
# Load label
# newline='' hands line-ending handling to the csv module (as its documentation
# recommends); the explicit encoding keeps parsing independent of the locale.
with open('efficientat/metadata/class_labels_indices.csv', 'r', newline='', encoding='utf-8') as f:
    reader = csv.reader(f, delimiter=',')
    lines = list(reader)

# The first row is skipped (presumably the header). Column 1 of every remaining
# row is the id and column 2 is the label. Comprehensions avoid leaving a
# module-level ``id`` behind that would shadow the builtin.
ids = [row[1] for row in lines[1:]]  # Each label has a unique id such as "/m/068hy"
labels = [row[2] for row in lines[1:]]
classes_num = len(labels)
import numpy as np
def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
    """Build a schedule that is the product of a ramp-up and a ramp-down factor.

    Args:
        warmup: Passed to ``exp_rampup`` as the ramp-up length.
        rampdown_length: Passed to ``linear_rampdown`` as the ramp-down length.
        start_rampdown: Passed to ``linear_rampdown`` as the ramp-down start.
        last_value: Passed to ``linear_rampdown`` as the final factor.

    Returns:
        A callable taking ``epoch`` and returning
        ``exp_rampup(warmup)(epoch) * linear_rampdown(...)(epoch)``.
    """
    warmup_factor = exp_rampup(warmup)
    decay_factor = linear_rampdown(rampdown_length, start_rampdown, last_value)

    def schedule(epoch):
        # Both factors are evaluated at the same epoch and multiplied.
        return warmup_factor(epoch) * decay_factor(epoch)

    return schedule
def exp_rampup(rampup_length):
    """Return an exponential ramp-up schedule (see https://arxiv.org/abs/1610.02242).

    The returned callable maps ``epoch`` to ``exp(-5 * (1 - e / rampup_length) ** 2)``
    with ``e`` being ``epoch`` clipped to ``[0.5, rampup_length]``, and to ``1.0``
    once ``epoch`` is no longer below ``rampup_length``.
    """
    def ramp(epoch):
        # Outside the ramp-up window the factor is constant.
        if not epoch < rampup_length:
            return 1.0
        # The lower clip bound of 0.5 keeps epoch 0 from sitting at the very
        # bottom of the curve.
        clipped = np.clip(epoch, 0.5, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))

    return ramp
def linear_rampdown(rampdown_length, start=0, last_value=0):
    """Return a schedule that decays linearly from 1 to ``last_value``.

    Args:
        rampdown_length: Number of epochs over which the factor decays.
        start: Epoch up to and including which the factor stays at 1.
        last_value: Factor returned once the decay has finished.

    Returns:
        A callable taking ``epoch`` and returning the factor for that epoch.
    """
    def decay(epoch):
        # Before (and at) the start epoch nothing is scaled down.
        if epoch <= start:
            return 1.
        # After the ramp-down window the factor stays at its final value.
        if not epoch - start < rampdown_length:
            return last_value
        # Inside the window: interpolate between 1 and last_value.
        epochs_left = rampdown_length - epoch + start
        return last_value + (1. - last_value) * epochs_left / rampdown_length

    return decay
import torch
def mixup(size, alpha):
    """Draw a random permutation and per-item mixing weights.

    Args:
        size: Number of items; used for both the permutation length and the
            number of weights drawn.
        alpha: Both shape parameters of the Beta distribution the weights are
            sampled from.

    Returns:
        A tuple ``(rn_indices, lam)``: a ``torch.randperm(size)`` index tensor
        and a float tensor of ``size`` weights, each ``max(l, 1 - l)`` for a
        Beta sample ``l`` (so every weight is at least 0.5).
    """
    permutation = torch.randperm(size)
    beta_draws = np.random.beta(alpha, alpha, size).astype(np.float32)
    # Keep the larger of l and 1 - l for every draw.
    dominant = np.maximum(beta_draws, 1 - beta_draws)
    return permutation, torch.FloatTensor(dominant)
from torch.distributions.beta import Beta
def mixstyle(x, p=0.4, alpha=0.4, eps=1e-6, mix_labels=False):
    """Randomly mix the dim-2-wise statistics of ``x`` between batch items.

    With probability of roughly ``1 - p`` the input is returned untouched.
    Otherwise each item is normalized with its own mean/std (computed over
    dims 1 and 3) and de-normalized with a convex combination of its own
    statistics and those of a randomly chosen other item.

    Args:
        x: 4-D tensor, batch first. Presumably laid out as
            (batch, channels, frequency, time) - confirm against the caller.
        p: Probability threshold; mixing is applied unless
            ``np.random.rand() > p``.
        alpha: Both concentration parameters of the Beta distribution used
            for the mixing weights.
        eps: Added to the variance before the square root.
        mix_labels: When true and mixing is applied, also return the
            permutation and the mixing weights.

    Returns:
        The (possibly mixed) tensor, or ``(tensor, perm, lmda)`` when
        ``mix_labels`` is true and mixing was applied. Note that the
        skip-mixing path returns the bare input tensor even when
        ``mix_labels`` is true.
    """
    if np.random.rand() > p:
        return x

    n_items = x.shape[0]

    # changed from dim=[2,3] to dim=[1,3] - from channel-wise statistics to frequency-wise statistics
    stat_dims = [1, 3]
    mean = torch.mean(x, dim=stat_dims, keepdim=True)
    variance = torch.var(x, dim=stat_dims, keepdim=True)
    std = torch.sqrt(variance + eps)

    # The statistics act as constants: no gradient flows through them.
    mean = mean.detach()
    std = std.detach()

    normalized = (x - mean) / std

    # One convex weight per batch item, broadcast over the remaining dims.
    lmda = Beta(alpha, alpha).sample((n_items, 1, 1, 1)).to(x.device)
    # Partner item whose statistics get blended in.
    perm = torch.randperm(n_items).to(x.device)

    def blend(stat):
        # Convex combination of an item's statistic and its partner's.
        return stat * lmda + stat[perm] * (1 - lmda)

    mixed_mean = blend(mean)
    mixed_std = blend(std)

    # De-normalize with the blended statistics.
    restyled = normalized * mixed_std + mixed_mean
    return (restyled, perm, lmda) if mix_labels else restyled