from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
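

# Fine-tunes the selected Stable Diffusion modules so that the model's noise
# prediction for `prompt` is pushed away from the concept and toward the
# unconditional (empty-prompt) prediction, erasing the concept from outputs.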
def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
    nsteps = 50

    diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()

    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

    pbar = tqdm(range(iterations))
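
    # The text embeddings for the target prompt and for the empty (neutral)
    # prompt are fixed for the whole run, so compute them once under no_grad.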
    with torch.no_grad():
        neutral_text_embeddings = diffuser.get_text_embeddings([''], n_imgs=1)
        positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)
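
    # Only the UNet is needed from here on; drop the VAE, text encoder and
    # tokenizer to free GPU memory.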
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    torch.cuda.empty_cache()

    for i in pbar:
        with torch.no_grad():
            diffuser.set_scheduler_timesteps(nsteps)
            optimizer.zero_grad()
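
            # Pick a random intermediate step and run diffusion up to it with
            # the edited weights to obtain a partially denoised latent.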
            iteration = torch.randint(1, nsteps - 1, (1,)).item()
            latents = diffuser.get_initial_latents(1, 512, 1)

            with finetuner:
                latents_steps, _ = diffuser.diffusion(
                    latents,
                    positive_text_embeddings,
                    start_iteration=0,
                    end_iteration=iteration,
                    guidance_scale=3,
                    show_progress=False
                )
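
            # Map the coarse DDIM step index onto the full 1000-step schedule
            # before querying per-timestep noise predictions.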
            diffuser.set_scheduler_timesteps(1000)
            iteration = int(iteration / nsteps * 1000)
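
            # Noise predictions from the frozen (original) weights, conditioned
            # on the target prompt and on the empty prompt.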
            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
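
        # Noise prediction from the fine-tuned weights: the only forward pass
        # that carries gradients.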
        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False

        # Push the fine-tuned prediction toward the neutral prediction and away
        # from the concept direction (positive - neutral).
        loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))  # loss = criteria(e_n, e_0) works best; try 5000 epochs

        loss.backward()
        optimizer.step()

    # Save the fine-tuned weights, then release everything still on the GPU.
    torch.save(finetuner.state_dict(), save_path)

    del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
    torch.cuda.empty_cache()

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)

    train(**vars(parser.parse_args()))
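
# Example invocation (the module patterns below are illustrative; valid values
# depend on how FineTunedModel matches module names in this repo):
#   python train.py --prompt "Van Gogh" --modules "unet" --freeze_modules "unet" \
#       --save_path erased.pt --iterations 1000 --lr 1e-5 --negative_guidance 1.0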