import torch
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import LMSDiscreteScheduler, DiffusionPipeline

# configurations
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
height, width = 64, 64          # output image size; latents are height//8 x width//8
guidance_scale = 8              # classifier-free guidance weight
blue_loss_scale = 200           # weight for the guidance loss (applied to every method, not just 'blue')
num_inference_steps = 50
elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)

pretrained_model_name_or_path = "segmind/tiny-sd"
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    low_cpu_mem_usage=True,
).to(torch_device)

# NOTE: explicit token= values are passed so the placeholder strings in
# concepts_mapping below are guaranteed to match the loaded embeddings; the
# token names themselves are local choices, not necessarily each repo's
# native token.
pipe.load_textual_inversion("sd-concepts-library/dreams", token="<dreams>")
pipe.load_textual_inversion("sd-concepts-library/midjourney-style", token="<midjourney-style>")
pipe.load_textual_inversion("sd-concepts-library/moebius", token="<moebius>")
pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante", token="<marc-allante>")
pipe.load_textual_inversion("sd-concepts-library/wlop-style", token="<wlop-style>")

concepts_mapping = {
    "Dream": "<dreams>",
    "Midjourney": "<midjourney-style>",
    "Marc Allante": "<marc-allante>",
    "Moebius": "<moebius>",
    "Wlop": "<wlop-style>",
}

def image_loss(images, method='elastic'):
    # elastic loss: distance between the image and an elastically distorted copy
    if method == 'elastic':
        transformed_imgs = elastic_transformer(images)
        error = torch.abs(transformed_imgs - images).mean()
    # symmetry loss: flip the image along the width and compare
    elif method == 'symmetry':
        flipped_image = torch.flip(images, [3])
        error = F.mse_loss(images, flipped_image)
    # saturation loss: distance to a heavily saturated copy
    elif method == 'saturation':
        transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
        error = torch.abs(transformed_imgs - images).mean()
    # blue loss: push the blue channel of every image in the batch toward 0.9
    elif method == 'blue':
        error = torch.abs(images[:, 2] - 0.9).mean()  # [:, 2] -> blue channel only
    else:
        raise ValueError(f"unknown loss method: {method}")
    return error
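# Quick smoke test for image_loss (illustrative, not part of the app): every
# method should return a scalar tensor for a (batch, 3, H, W) image in [0, 1].
# Uncomment to run once the pipeline above has loaded:
#
#   dummy = torch.rand(1, 3, 64, 64, device=torch_device)
#   for m in ("elastic", "symmetry", "saturation", "blue"):
#       print(m, image_loss(dummy, method=m).item())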

HTML_TEMPLATE = """
<div style="text-align: center;">
    <h1>Art Generator</h1>
    <p>Generate new art in five different styles by providing a prompt.</p>
    <ul style="list-style: none;">
        <li><b>Midjourney</b>: Midjourney Style</li>
        <li><b>Dreams</b>: Dreams Style</li>
        <li><b>Moebius</b>: Moebius Style</li>
        <li><b>Allante</b>: Hong Kong-born artist inspired by western and eastern influences</li>
        <li><b>Wlop</b>: WLOP (born 1987) is known for Digital Art (NFTs)</li>
    </ul>
</div>
""" def get_examples(): examples = [ ['A powerful man in dreadlocks', 'Dream', 'Symmetry', 45], ['World Peace', 'Marc Allante', 'Saturation', 147], ['Storm trooper in the desert, dramatic lighting, high-detail', 'Moebius', 'Elastic', 28], ['Delicious Italian pizza on a table, a window in the background overlooking a city skyline', 'Wlop', 'Blue', 50], ] return(examples) def latents_to_pil(latents): # bath of latents -> list of images latents = (1 / 0.18215) * latents with torch.no_grad(): image = pipe.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # 0 to 1 image = image.detach().cpu().permute(0, 2, 3, 1).numpy() image = (image * 255).round().astype("uint8") return Image.fromarray(image[0]) def generate_art(prompt, concept, method, seed): prompt = f"{prompt} in the style of {concepts_mapping[concept]}" img_no_loss = latents_to_pil(generate_image(prompt, method, seed)) img_loss = latents_to_pil(generate_image(prompt, method, seed, loss_apply=True)) return([img_no_loss, img_loss]) def generate_image(prompt, method, seed, loss_apply=False): generator = torch.manual_seed(seed) batch_size = 1 method = method.lower() # scheduler scheduler = LMSDiscreteScheduler(beta_start = 0.00085, beta_end = 0.012, beta_schedule = "scaled_linear", num_train_timesteps = 1000) scheduler.set_timesteps(50) scheduler.timesteps = scheduler.timesteps.to(torch.float32) # text embeddings of the prompt text_input = pipe.tokenizer([prompt], padding='max_length', max_length = pipe.tokenizer.model_max_length, truncation= True, return_tensors="pt") input_ids = text_input.input_ids.to(torch_device) with torch.no_grad(): text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0] max_length = text_input.input_ids.shape[-1] uncond_input = pipe.tokenizer( [""] * 1, padding="max_length", max_length= max_length, return_tensors="pt" ) with torch.no_grad(): uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0] text_embeddings = torch.cat([uncond_embeddings,text_embeddings]) # random latent latents = torch.randn( (batch_size, pipe.unet.config.in_channels, height// 8, width //8), generator = generator, ) latents = latents.to(torch_device) latents = latents * scheduler.init_noise_sigma for i, t in tqdm(enumerate(scheduler.timesteps), total = len(scheduler.timesteps)): latent_model_input = torch.cat([latents] * 2) sigma = scheduler.sigmas[i] latent_model_input = scheduler.scale_model_input(latent_model_input, t) with torch.no_grad(): noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if loss_apply and i%5 == 0: latents = latents.detach().requires_grad_() latents_x0 = latents - sigma * noise_pred # use vae to decode the image denoised_images = pipe.vae.decode((1/ 0.18215) * latents_x0).sample / 2 + 0.5 # range(0,1) loss = image_loss(denoised_images, method) * blue_loss_scale cond_grad = torch.autograd.grad(loss, latents)[0] latents = latents.detach() - cond_grad * sigma**2 latents = scheduler.step(noise_pred,t, latents).prev_sample return latents