Spaces:
Running
Running
import os | |
import json | |
import random | |
import glob | |
import torch | |
import einops | |
import torchvision | |
import safetensors.torch as sf | |
def write_to_json(data, file_path):
    """Serialize ``data`` as indented JSON to ``file_path``.

    The JSON is first written to a sibling ``<file_path>.tmp`` file and then
    moved over the destination with ``os.replace``, so the destination is
    only ever swapped in after the dump has completed.

    Args:
        data: Any object accepted by ``json.dump``.
        file_path: Destination path (a ``str``; ".tmp" is appended to it).
    """
    staging_path = file_path + ".tmp"
    with open(staging_path, mode='wt', encoding='utf-8') as staging_file:
        json.dump(data, staging_file, indent=4)
    # Swap the finished file into place in a single step.
    os.replace(staging_path, file_path)
def read_from_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is opened in text mode and decoded as UTF-8.
    """
    with open(file_path, mode='rt', encoding='utf-8') as handle:
        return json.load(handle)
def get_active_parameters(m):
    """Return a ``{name: parameter}`` dict of the trainable parameters of ``m``.

    Only parameters whose ``requires_grad`` flag is set are included; names
    are those yielded by ``m.named_parameters()``.
    """
    trainable = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            trainable[name] = param
    return trainable
def cast_training_params(m, dtype=torch.float32):
    """Cast the trainable parameters of ``m`` to ``dtype`` in place.

    Parameters with ``requires_grad == False`` are left untouched.

    Args:
        m: Module whose parameters are inspected.
        dtype: Target dtype for the trainable parameters.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    for weight in trainable:
        # Rebind the underlying storage so the Parameter object itself is kept.
        weight.data = weight.to(dtype)
def set_attr_recursive(obj, attr, value):
    """Assign ``value`` to a dotted attribute path starting at ``obj``.

    ``attr`` is split on "."; every segment except the last is resolved with
    ``getattr`` to walk down the object graph, and the final segment is set
    with ``setattr`` on the object reached.

    Args:
        obj: Root object to start from.
        attr: Dotted path such as ``"a.b.c"``.
        value: Value to assign to the final attribute.
    """
    *parent_names, leaf_name = attr.split(".")
    target = obj
    for segment in parent_names:
        target = getattr(target, segment)
    setattr(target, leaf_name, value)
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
    """Mix two same-shaped tensors by choosing whole batch items from each.

    For every index along dimension 0 the output takes the item from ``a``
    where the mask is true and from ``b`` otherwise.

    Args:
        a: First tensor.
        b: Second tensor; must have the same shape as ``a``.
        probability_a: Chance of picking ``a`` for each batch item when no
            mask is supplied.
        mask_a: Optional per-item selection mask with ``a.size(0)`` elements.
            When omitted, one is drawn with ``torch.rand``.

    Returns:
        Tensor with the same shape as ``a`` and ``b``.
    """
    assert a.shape == b.shape, "Tensors must have the same shape"
    num_items = a.size(0)

    if mask_a is None:
        mask_a = torch.rand(num_items) < probability_a

    # One leading batch axis followed by singleton axes so the mask
    # broadcasts across every remaining dimension.
    broadcast_shape = (num_items,) + (1,) * (a.dim() - 1)
    selector = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(selector, a, b)
def zero_module(module):
    """Zero every parameter of ``module`` in place and return the module."""
    for weight in module.parameters():
        # Operate on a detached view so the in-place write bypasses autograd.
        weight.detach().zero_()
    return module
def load_last_state(model, folder='accelerator_output'):
    """Load the most recently modified ``model.safetensors`` into ``model``.

    ``folder`` is searched recursively for files named ``model.safetensors``;
    the one with the newest modification time is loaded non-strictly, and any
    missing or unexpected keys are printed. If no file is found a message is
    printed and nothing is loaded.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=False)``.
        folder: Root directory to search.
    """
    candidates = glob.glob(
        os.path.join(folder, '**', 'model.safetensors'),
        recursive=True,
    )
    if not candidates:
        print("No model.safetensors files found in the specified folder.")
        return

    latest_checkpoint = max(candidates, key=os.path.getmtime)
    weights = sf.load_file(latest_checkpoint)
    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

    # Report key mismatches, but only when there is something to report.
    for label, keys in (("Missing keys:", missing_keys), ("Unexpected keys:", unexpected_keys)):
        if keys:
            print(label, keys)

    print("Loaded model state from:", latest_checkpoint)
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of comma-separated tags.

    ``tags_str`` is split on ", ". A target count is drawn uniformly from
    ``[min_length, max_length]`` and capped at the number of available tags;
    that many tags are then sampled without replacement (in random order)
    and re-joined with ", ".

    Args:
        tags_str: Tags separated by ", ".
        min_length: Smallest number of tags to draw.
        max_length: Largest number of tags to draw.

    Returns:
        The sampled tags joined by ", ".
    """
    available = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    chosen = random.sample(available, k=min(wanted, len(available)))
    return ', '.join(chosen)
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Write a 5-D ``(b, c, t, h, w)`` tensor to a video file as a tiled grid.

    Values are clamped to ``[-1, 1]`` and mapped linearly to ``uint8``. The
    batch items are tiled into a grid whose column count is the largest of
    6, 5, 4, 3, 2 that divides the batch size, or the batch size itself when
    none does. The parent directory of ``output_filename`` is created if
    needed.

    Args:
        x: Tensor with exactly five dimensions, unpacked as ``b, c, t, h, w``.
        output_filename: Destination path for the video.
        fps: Frame rate passed to the video writer.

    Returns:
        The ``uint8`` CPU tensor that was written, laid out as
        ``t, (rows * h), (cols * w), c``.
    """
    # Unpacking all five axes also rejects tensors of any other rank.
    b, _c, _t, _h, _w = x.shape
    per_row = next((cols for cols in (6, 5, 4, 3, 2) if b % cols == 0), b)

    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)

    torchvision.io.write_video(
        output_filename,
        frames,
        fps=fps,
        video_codec='h264',
        options={'crf': '0'},
    )
    return frames
def save_bcthw_as_png(x, output_filename):
    """Write a ``(b, c, t, h, w)`` tensor to a single PNG contact sheet.

    Values are clamped to ``[-1, 1]`` and mapped linearly to ``uint8``. Batch
    items are stacked vertically and time steps horizontally, giving an
    image of shape ``c, (b * h), (t * w)``. The parent directory of
    ``output_filename`` is created if needed.

    Args:
        x: Tensor laid out as ``b, c, t, h, w``.
        output_filename: Destination path for the PNG.

    Returns:
        ``output_filename``, unchanged.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')

    torchvision.io.write_png(sheet, output_filename)
    return output_filename
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors, zero-padding each up to their element-wise max shape.

    When the shapes already match this is a plain ``tensor1 + tensor2``.
    Otherwise each tensor is copied into the leading corner of a zero buffer
    whose size along every dimension is the larger of the two inputs, and
    the two buffers are summed.

    The padded buffers use the promoted dtype of the two inputs and live on
    ``tensor1``'s device, so the result has the same dtype/device as the
    equal-shape branch would produce (rather than always being a float32 CPU
    tensor).

    Args:
        tensor1: First operand.
        tensor2: Second operand. Must have the same number of dimensions as
            ``tensor1``; mismatched ranks are not supported.

    Returns:
        The sum, with shape ``max(shape1, shape2)`` per dimension.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # Match what ``tensor1 + tensor2`` would yield for dtype, and keep the
    # computation on tensor1's device instead of defaulting to float32/CPU.
    result_dtype = torch.result_type(tensor1, tensor2)
    padded_tensor1 = tensor1.new_zeros(new_shape, dtype=result_dtype)
    padded_tensor2 = tensor1.new_zeros(new_shape, dtype=result_dtype)

    # Each input occupies the leading corner of its zero buffer.
    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2