|
import numpy as np
|
|
import torch
|
|
import safetensors
|
|
from safetensors.torch import save_file
|
|
import matplotlib.pyplot as plt
|
|
# Perturb the diffusion-model weights of an SD3 checkpoint with small Gaussian
# noise and save the result under a new file name.
#
# Every tensor whose key contains all substrings in `parts` receives additive
# zero-mean noise whose std is NOISE_SCALE times that tensor's own std; all
# other tensors are saved unchanged, together with the original metadata.

SRC_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors'
DST_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors_perturbed3.safetensors'
NOISE_SCALE = .02  # noise std, relative to each tensor's own std
SEED = None  # set to an int to make the perturbation reproducible


def _perturb_matching(tensors, parts, scale=NOISE_SCALE, generator=None):
    """Add relative Gaussian noise, in place, to selected entries of *tensors*.

    A tensor is selected when its key contains every substring in *parts*.
    Each selected tensor ``t`` is replaced by ``t + N(0, (scale * t.std())**2)``,
    computed in float32 and cast back to ``t``'s original dtype.

    Args:
        tensors: mapping of key -> torch.Tensor; updated in place.
        parts: substrings that must all occur in a key for it to be perturbed.
        scale: noise standard deviation as a fraction of the tensor's std.
        generator: optional torch.Generator for reproducible noise; it must
            live on the same device as the tensors.

    Returns:
        The number of tensors that were perturbed.
    """
    count = 0
    # Replacing the value of an existing key does not resize the dict, so
    # assigning into `tensors` while iterating over its items is safe.
    for key, tensor in tensors.items():
        if not all(part in key for part in parts):
            continue
        if not tensor.is_floating_point() or tensor.numel() < 2:
            # std() raises on integer tensors, and the unbiased std of a
            # 0/1-element tensor is NaN, which would turn the weight into NaN.
            print(f'{key}: skipped')
            continue
        # Work in float32 so the statistics and the addition are not rounded
        # in half precision, then cast back to the stored dtype.
        values = tensor.float()
        std = values.std()
        print(f'{key}: {std.item()}')
        noise = torch.randn(values.shape, generator=generator,
                            dtype=values.dtype, device=values.device)
        tensors[key] = (values + noise * (std * scale)).to(tensor.dtype)
        count += 1
    return count


with safetensors.safe_open(SRC_PATH, 'pt') as model:
    keys = model.keys()
    dic = {key: model.get_tensor(key) for key in keys}
    # Read the metadata while the file is still open; it is saved unchanged.
    metadata = model.metadata()

parts = ['diffusion_model']

generator = None if SEED is None else torch.Generator().manual_seed(SEED)
count = _perturb_matching(dic, parts, NOISE_SCALE, generator)
print(count)

save_file(dic, DST_PATH, metadata)