|
import math |
|
import time |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def select_device(device=""):
    """Choose the torch device to run on.

    Args:
        device: Requested device. ``"cpu"`` (case-insensitive, surrounding
            whitespace ignored) forces the CPU. Any other value, including
            ``""`` and ``None``, requests CUDA. A ``torch.device`` is also
            accepted and is interpreted through its string form.

    Returns:
        ``torch.device("cuda:0")`` when CUDA was requested and
        ``torch.cuda.is_available()`` is true, otherwise
        ``torch.device("cpu")``. Any device index in the request is not
        used; CUDA always resolves to ``cuda:0``.
    """
    # str() lets callers pass a torch.device; None behaves like the default "".
    requested = "" if device is None else str(device).strip().lower()
    cpu = requested == "cpu"
    cuda = not cpu and torch.cuda.is_available()
    return torch.device("cuda:0" if cuda else "cpu")
|
|
|
|
|
def time_synchronized():
    """Return the current wall-clock time after syncing CUDA, if present.

    When ``torch.cuda.is_available()`` is true, ``torch.cuda.synchronize()``
    is called first; the returned value is ``time.time()`` taken afterwards.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    now = time.time()
    return now
|
|
|
|
|
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that precedes it.

    Builds a new ``nn.Conv2d`` (always with a bias) whose weights and bias
    reproduce ``bn(conv(x))`` when ``bn`` uses its running statistics.

    Args:
        conv: The ``nn.Conv2d`` whose output feeds ``bn``.
        bn: The ``nn.BatchNorm2d`` to fold in; its ``weight``, ``bias``,
            ``running_mean``, ``running_var`` and ``eps`` are read.

    Returns:
        A new ``nn.Conv2d`` on ``conv.weight.device`` with
        ``requires_grad`` disabled on its parameters. ``conv`` and ``bn``
        are not modified.
    """
    # no_grad keeps the fused parameters as plain leaf tensors: copy_() from
    # tensors that require grad would otherwise attach a grad_fn to them.
    with torch.no_grad():
        fusedconv = (
            nn.Conv2d(
                conv.in_channels,
                conv.out_channels,
                kernel_size=conv.kernel_size,
                stride=conv.stride,
                padding=conv.padding,
                # dilation and padding_mode must match, or the fused layer
                # computes a different convolution than the original.
                dilation=conv.dilation,
                groups=conv.groups,
                bias=True,
                padding_mode=conv.padding_mode,
            )
            .requires_grad_(False)
            .to(conv.weight.device)
        )

        # Per-output-channel BN scale: gamma / sqrt(running_var + eps).
        std = torch.sqrt(bn.running_var + bn.eps)
        scale = bn.weight.div(std)

        # Broadcasting over (out, in/groups, kH, kW) scales each output
        # channel without materialising an out x out diagonal matrix.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        if conv.bias is None:
            # Match the conv's dtype so non-float32 layers do not hit a
            # dtype mismatch when combined with the BN terms.
            b_conv = torch.zeros(
                conv.out_channels,
                dtype=conv.weight.dtype,
                device=conv.weight.device,
            )
        else:
            b_conv = conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(std)
        fusedconv.bias.copy_(b_conv * scale + b_bn)

    return fusedconv
|
|
|
|
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Resize an image tensor by ``ratio`` and pad it to a target size.

    Args:
        img: Tensor whose ``shape[2:]`` unpacks to ``(height, width)``.
        ratio: Scale factor. ``1.0`` returns ``img`` itself, untouched.
        same_shape: When true, the pad target is the original
            ``(height, width)``; otherwise each scaled dimension is rounded
            up to a multiple of ``gs``.
        gs: Size multiple used when ``same_shape`` is false.

    Returns:
        ``img`` when ``ratio == 1.0``; otherwise the bilinearly interpolated
        tensor padded on the right and bottom with the constant ``0.447``.
    """
    if ratio == 1.0:
        return img

    height, width = img.shape[2:]
    new_h = int(height * ratio)
    new_w = int(width * ratio)
    resized = F.interpolate(
        img, size=(new_h, new_w), mode="bilinear", align_corners=False
    )

    if same_shape:
        target_h, target_w = height, width
    else:
        target_h = math.ceil(height * ratio / gs) * gs
        target_w = math.ceil(width * ratio / gs) * gs

    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    padding = [0, target_w - new_w, 0, target_h - new_h]
    return F.pad(resized, padding, value=0.447)
|
|
|
|
|
def initialize_weights(model):
    """Adjust hyper-parameters of selected layers in ``model`` in place.

    Walks ``model.modules()`` and matches on each module's exact type
    (subclasses are not matched):

    * ``nn.BatchNorm2d``: ``eps`` is set to ``1e-3`` and ``momentum`` to
      ``0.03``.
    * ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``:
      ``inplace`` is set to ``True``.

    Every other module, including ``nn.Conv2d``, is left unchanged.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)

    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
            continue
        if kind in inplace_activations:
            module.inplace = True
|
|
|
|
|
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Iterates ``b.__dict__`` and calls ``setattr(a, name, value)`` for each
    entry that passes all of these filters:

    * if ``include`` is non-empty, the name must be in ``include``;
    * the name must not start with ``"_"``;
    * the name must not be in ``exclude``.

    Args:
        a: Object that receives the attributes.
        b: Object whose ``__dict__`` is read.
        include: Optional collection of names to restrict the copy to.
        exclude: Collection of names to skip.
    """
    for name, value in b.__dict__.items():
        # An empty include means "no restriction".
        if len(include) and name not in include:
            continue
        if name.startswith("_") or name in exclude:
            continue
        setattr(a, name, value)
|
|