|
import os |
|
import math |
|
import torch |
|
import logging |
|
import subprocess |
|
import numpy as np |
|
import torch.distributed as dist |
|
|
|
|
|
from torch import inf |
|
from PIL import Image |
|
from typing import Union, Iterable |
|
from collections import OrderedDict |
|
from torch.utils.tensorboard import SummaryWriter |
|
from typing import Dict |
|
import torch_dct |
|
|
|
from diffusers.utils import is_bs4_available, is_ftfy_available |
|
|
|
import html |
|
import re |
|
import urllib.parse as ul |
|
|
|
if is_bs4_available(): |
|
from bs4 import BeautifulSoup |
|
|
|
if is_ftfy_available(): |
|
import ftfy |
|
|
|
import torch.fft as fft |
|
|
|
# Type alias for arguments that accept either a single tensor or an iterable of tensors.
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_model(model_name):
    """
    Load a checkpoint file and return the weights stored inside it.

    The file is deserialized with ``torch.load`` using the identity
    ``map_location`` (``lambda storage, loc: storage``), which loads tensors
    onto the CPU. If the loaded mapping contains an ``"ema"`` entry, that entry
    is returned; otherwise the ``"model"`` entry is returned.

    :param model_name: path to the checkpoint file
    :return: the object stored under ``"ema"`` (preferred) or ``"model"``
    :raises AssertionError: if ``model_name`` is not an existing file
    :raises KeyError: if the checkpoint has neither ``"ema"`` nor ``"model"``
    """
    # NOTE: this check is an assert, so it is skipped under ``python -O``.
    assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'

    state = torch.load(model_name, map_location=lambda storage, loc: storage)

    if "ema" not in state:
        weights = state["model"]
        print("Using model ckpt!")
        return weights

    print('Using ema ckpt!')
    return state["ema"]
|
|
|
def save_video_grid(video, nrow=None):
    """
    Tile a batch of videos into a single video whose frames are grids.

    :param video: tensor unpacked as ``(b, t, h, w, c)``
    :param nrow: number of grid rows; defaults to ``ceil(sqrt(b))``
    :return: ``torch.uint8`` tensor of shape
        ``(t, (1 + h) * nrow + 1, (1 + w) * ncol + 1, c)`` where
        ``ncol = ceil(b / nrow)``. Clip ``i`` is placed at grid row
        ``i // ncol`` and column ``i % ncol``. Clips are separated from each
        other and from the outer edge by a 1-pixel zero border. Input values
        are cast to uint8 on assignment, so they are presumably already in
        [0, 255].
    """
    b, t, h, w, c = video.shape

    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    padding = 1

    video_grid = torch.zeros(
        (t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c),
        dtype=torch.uint8,
    )

    for i in range(b):
        # Use distinct names so the channel count ``c`` is not shadowed.
        row, col = divmod(i, ncol)
        # The canvas reserves one leading ``padding`` on each axis; offset by it
        # so the top/left border matches the bottom/right border.
        start_r = padding + (padding + h) * row
        start_c = padding + (padding + w) * col
        video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]

    return video_grid
|
|
|
def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, nrow=None, fps=8):
    """
    Write a batch of videos to ``path`` as one animated grid.

    For every time step, the batch is tiled with ``torchvision.utils.make_grid``.
    The resulting frames are written with ``imageio.mimsave``.

    :param videos: tensor laid out as ``(b, c, t, h, w)`` (per the rearrange pattern)
    :param path: output file passed to ``imageio.mimsave``
    :param rescale: if True, map values from [-1, 1] to [0, 1] before saving
    :param nrow: images per grid row (``make_grid``'s ``nrow``); defaults to
        ``ceil(sqrt(b))``
    :param fps: frames per second for the written file
    """
    # Function-scope imports: these packages are only needed when this is called.
    from einops import rearrange
    import imageio
    import torchvision

    b, _, _, _, _ = videos.shape
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))

    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=nrow)
        # (c, H, W) -> (H, W, c); squeeze a trailing singleton channel if present.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0
        # Clamp so out-of-range values saturate rather than wrap in the uint8
        # cast. Detach and move to CPU so .numpy() also works for CUDA or
        # grad-tracking inputs.
        x = (x.clamp(0.0, 1.0) * 255).detach().cpu().numpy().astype(np.uint8)
        outputs.append(x)

    imageio.mimsave(path, outputs, fps=fps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_env():
    """Collect the information of the running environments.

    Prints every entry returned by ``mmcv.utils.collect_env``, plus an
    ``'MMClassification'`` entry holding ``get_git_hash()`` truncated to 7
    characters. It then prints ``torch.cuda.get_arch_list()`` and
    ``torch.version.cuda``. Returns None.
    """
    # Function-scope imports: mmcv is only required when this is called.
    from mmcv.utils import collect_env as collect_base_env
    from mmcv.utils import get_git_hash

    env_info = collect_base_env()
    env_info['MMClassification'] = get_git_hash()[:7]

    for name, val in env_info.items():
        print(f'{name}: {val}')

    print(torch.cuda.get_arch_list())
    print(torch.version.cuda)
|
|
|
|
|
|
|
|
|
|
|
|
|
def dct_low_pass_filter(dct_coefficients, percentage=0.3):
    """
    Build a low-pass mask for a tensor of DCT coefficients.

    The mask has the same shape, dtype and device as ``dct_coefficients``
    (via ``torch.zeros_like``). It is 1 on the leading
    ``int(H * percentage)`` x ``int(W * percentage)`` block of the last two
    dimensions and 0 elsewhere. All leading dimensions are left unrestricted.
    The mask is returned, not applied; multiply it with the coefficients to
    filter them.

    :param dct_coefficients: tensor with at least 2 dimensions; the last two
        are the ones that get cut off
    :param percentage: fraction of coefficients to keep along each of the last
        two dimensions (between 0 and 1)
    :return: 0/1 mask tensor shaped like ``dct_coefficients``
    """
    cutoff_x = int(dct_coefficients.shape[-2] * percentage)
    cutoff_y = int(dct_coefficients.shape[-1] * percentage)

    mask = torch.zeros_like(dct_coefficients)
    # Ellipsis covers any number of leading dims, so this works for any rank >= 2
    # (for 5-D input it equals mask[:, :, :, :cutoff_x, :cutoff_y]).
    mask[..., :cutoff_x, :cutoff_y] = 1

    return mask
|
|
|
def normalize(tensor):
    """Normalize a tensor to the [0, 1] range using its global min and max.

    A constant tensor (max == min) maps to all zeros instead of NaN.
    """
    min_val = tensor.min()
    max_val = tensor.max()
    value_range = max_val - min_val
    # Avoid 0/0 -> NaN for constant input. The numerator is 0 everywhere in
    # that case, so dividing by 1 yields zeros with the usual result dtype.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return (tensor - min_val) / value_range
|
|
|
def denormalize(tensor, max_val_target, min_val_target):
    """Map a tensor from [0, 1] back to [min_val_target, max_val_target].

    Linear map: 0 -> ``min_val_target`` and 1 -> ``max_val_target``.
    Note the argument order: the target maximum comes before the target minimum.
    """
    span = max_val_target - min_val_target
    return tensor * span + min_val_target
|
|
|
def exchanged_mixed_dct_freq(noise, base_content, LPF_3d, normalized=False):
    """
    Mix low frequencies of ``base_content`` with high frequencies of ``noise``.

    Both inputs are transformed with ``torch_dct.dct_3d(..., 'ortho')``. The
    spectrum of ``base_content`` is weighted by ``LPF_3d`` and that of
    ``noise`` by ``1 - LPF_3d``. The two are summed and mapped back with
    ``torch_dct.idct_3d(..., 'ortho')``.

    :param noise: tensor supplying the high-frequency part
    :param base_content: tensor supplying the low-frequency part
    :param LPF_3d: weighting applied to DCT coefficients (presumably a 0/1
        low-pass mask broadcastable to the inputs -- confirm with callers)
    :param normalized: unused; kept for interface compatibility
    :return: the mixed tensor after the inverse DCT (not a frequency tensor)
    """
    low_pass = LPF_3d
    high_pass = 1 - low_pass

    noise_high = torch_dct.dct_3d(noise, 'ortho') * high_pass
    content_low = torch_dct.dct_3d(base_content, 'ortho') * low_pass

    return torch_dct.idct_3d(content_low + noise_high, 'ortho')