from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm

def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
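    """Fine-tune the selected diffuser modules so that generations for `prompt`
    are pushed away from the concept it describes, using a negative-guidance
    target built from the frozen model's predictions."""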
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))

    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)

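    # The prompt embeddings are precomputed above and training only operates on
    # latents, so the VAE, text encoder, and tokenizer can be dropped to free GPU memory.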
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

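    # Each iteration: pick a random DDIM step, partially denoise fresh noise with
    # the current edited weights, then compare noise predictions at that step.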
    for i in pbar:
        with torch.no_grad():
            diffuser.set_scheduler_timesteps(nsteps)

            optimizer.zero_grad()

            iteration = torch.randint(1, nsteps - 1, (1,)).item()

            latents = diffuser.get_initial_latents(1, 512, 1)

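            # Partially denoise the noise latents up to the sampled step while the
            # finetuned modules are active (FineTunedModel is used as a context manager).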
            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )

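            # Switch back to the full 1000-step schedule and map the DDIM step
            # index onto it before querying the noise predictions.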
            diffuser.set_scheduler_timesteps(1000)

            iteration = int(iteration / nsteps * 1000)

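            # Noise predictions without the finetuned modules active (i.e. the
            # original weights), conditioned on the target prompt and on the
            # empty (neutral) prompt.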
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

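        # Prediction from the edited (finetuned) modules on the target prompt;
        # this is the only tensor that carries gradients.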
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

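        # Erasure objective: train the edited model's prediction toward the
        # negatively guided target e_neutral - negative_guidance * (e_positive - e_neutral),
        # steering the model away from the prompt's concept. The frozen-model
        # predictions act as constant targets.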
        loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))  # loss = criteria(e_n, e_0) works the best, try 5000 epochs

        loss.backward()
        optimizer.step()

    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents

    torch.cuda.empty_cache()

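# Illustrative CLI invocation (the script name and the --modules / --freeze_modules
# values below are hypothetical placeholders; the expected patterns depend on how
# FineTunedModel matches module names):
#   python train.py --prompt "Van Gogh style" --modules unet \
#       --freeze_modules attn1 --iterations 1000 --lr 1e-5 \
#       --negative_guidance 1.0 --save_path vangogh.pt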
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))