---
license: apache-2.0
language:
- zh
- en
- fr
- de
- ja
- kg
base_model:
- stabilityai/stable-diffusion-xl-base-1.0
pipeline_tag: text-to-image
---

![PEA-Diffusion](./PEA-Diffusion.png)

Text-to-image diffusion models are well known for generating realistic images from textual prompts. However, existing work has focused predominantly on English, and non-English text-to-image models remain poorly supported. Commonly used translation-based approaches cannot address generation problems tied to a language's culture, while training from scratch on a language-specific dataset is prohibitively expensive. In this paper, we propose a simple plug-and-play language transfer method based on knowledge distillation. All we need to do is train a lightweight MLP-like parameter-efficient adapter (PEA) with only 6M parameters, using teacher knowledge distillation and a small parallel data corpus. Surprisingly, we find that even with the UNet parameters frozen, the model achieves remarkable performance on a language-specific prompt evaluation set, demonstrating that PEA can unlock the latent generation ability of the original UNet. On a general prompt evaluation set, it also closely approaches the performance of the English text-to-image model. Furthermore, the adapter can be used as a plugin to achieve significant results in downstream cross-lingual text-to-image generation tasks.

# Usage

We provide adapter examples for models such as [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), [Playground v2.5](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), and [stable-cascade](https://huggingface.co/stabilityai/stable-cascade). For SD3, see https://huggingface.co/OPPOer/MultilingualSD3-adapter; for FLUX.1, see https://huggingface.co/OPPOer/MultilingualFLUX.1-adapter.

## `SDXL`

We use the multilingual encoder [Mul-OpenCLIP](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k). As mentioned in the paper, the base model below can be replaced with any SDXL-derived model, including sampling-accelerated variants, and the adapter can be used with them directly.

```python
import os
import torch
import torch.nn as nn
from PIL import Image
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import open_clip


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        # x2: per-token embeddings used as the UNet's encoder_hidden_states
        x2 = self.act_fn(x)
        x2 = self.fc3(x2)
        if self.use_residual:
            x = x + residual
        # x1: mean-pooled embedding used as SDXL's `text_embeds` condition
        x1 = torch.mean(x, 1)
        return x1, x2


class StableDiffusionTest():
    def __init__(self, model_id, text_encoder_path, proj_path):
        super().__init__()
        self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
        self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder.text.output_tokens = True
        # `device` and `dtype` are module-level globals set in the `__main__` block
        self.text_encoder = self.text_encoder.to(device, dtype=dtype)
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
        scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.pipe = StableDiffusionXLPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=dtype).to(device)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
        # PEA adapter: 1024-d Mul-OpenCLIP token features -> 1280-d pooled + 2048-d token embeddings
        self.proj = MLP(1024, 1280, 1024, 2048, use_residual=False).to(device, dtype=dtype)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))

    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_input_ids = self.tokenizer(prompt).to(device)
        _, text_embeddings = self.text_encoder.encode_text(text_input_ids)
        add_text_embeds, text_embeddings_2048 = self.proj(text_embeddings)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings_2048.shape
        text_embeddings_2048 = text_embeddings_2048.repeat(1, num_images_per_prompt, 1)
        text_embeddings_2048 = text_embeddings_2048.view(bs_embed * num_images_per_prompt, seq_len, -1)
        add_text_embeds = add_text_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input_ids = self.tokenizer(uncond_tokens).to(device)
            _, uncond_embeddings = self.text_encoder.encode_text(uncond_input_ids)
            add_text_embeds_uncond, uncond_embeddings_2048 = self.proj(uncond_embeddings)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings_2048.shape[1]
            uncond_embeddings_2048 = uncond_embeddings_2048.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings_2048 = uncond_embeddings_2048.view(batch_size * num_images_per_prompt, seq_len, -1)
            add_text_embeds_uncond = add_text_embeds_uncond.repeat(1, num_images_per_prompt).view(batch_size * num_images_per_prompt, -1)

            # stack unconditional and conditional embeddings into one batch
            text_embeddings_2048 = torch.cat([uncond_embeddings_2048, text_embeddings_2048])
            add_text_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds])

        return text_embeddings_2048, add_text_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # 0. Default height and width to unet
        height = height or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        width = width or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        # self.pipe.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.pipe._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, add_text_embeds = self.encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)

        # 4. Prepare timesteps
        self.pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.pipe.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.pipe.unet.config.in_channels
        latents = self.pipe.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        # TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)

        add_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype)
        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        # 7. Denoising loop
        for i, t in enumerate(self.pipe.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # reset the VAE to float32 before selectively casting parts of it below
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                LoRAXFormersAttnProcessor,
                LoRAAttnProcessor2_0,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(latents.dtype)
            self.vae.decoder.conv_in.to(latents.dtype)
            self.vae.decoder.mid_block.to(latents.dtype)
        else:
            latents = latents.float()

        # 8. Post-processing
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type="np")

        # 9. Convert to PIL
        if output_type == "pil":
            image = self.pipe.numpy_to_pil(image)

        return image


if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16

    # local path to the downloaded Mul-OpenCLIP checkpoint
    text_encoder_path = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    model_id = "stablediffusionapi/protovision-xl-v6.6"
    # local path to the downloaded PEA adapter weights
    proj_path = "OPPOer/PEA-Diffusion/pytorch_model.bin"

    sdt = StableDiffusionTest(model_id, text_encoder_path, proj_path)

    batch = 2
    height = 1024
    width = 1024
    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("SDXL.png")
```

## `Playground v2.5`

We use the multilingual encoder [Mul-OpenCLIP](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k).

```python
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
import open_clip


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


class MLP(nn.Module):
    def __init__(self, in_dim=1024, out_dim=1280, hidden_dim=2048, out_dim1=2048, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
        )
        self.fc = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.projector(x)
        # x2: per-token embeddings used as the UNet's encoder_hidden_states
        x2 = nn.GELU()(x)
        x2 = self.fc(x2)
        if self.use_residual:
            x = x + residual
        # x1: mean-pooled embedding used as the `text_embeds` condition
        x1 = torch.mean(x, 1)
        return x1, x2


class StableDiffusionTest():
    def __init__(self, model_id, text_encoder_path, proj_path):
        super().__init__()
        self.text_encoder, _, preprocess = open_clip.create_model_and_transforms('xlm-roberta-large-ViT-H-14', pretrained=text_encoder_path)
        self.tokenizer = open_clip.get_tokenizer('xlm-roberta-large-ViT-H-14')
        self.text_encoder.text.output_tokens = True
        # `device` and `dtype` are module-level globals set in the `__main__` block
        self.text_encoder = self.text_encoder.to(device, dtype=dtype)
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
        self.pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, variant="fp16").to(device)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.pipe.vae_scale_factor)
        self.proj = MLP(1024, 1280, 2048, 2048, use_residual=False).to(device, dtype=dtype)
        self.proj.load_state_dict(torch.load(proj_path, map_location="cpu"))

    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_input_ids = self.tokenizer(prompt).to(device)
        _, text_embeddings = self.text_encoder.encode_text(text_input_ids)
        add_text_embeds, text_embeddings_2048 = self.proj(text_embeddings)

        # duplicate text embeddings for each generation per prompt
        bs_embed, seq_len, _ = text_embeddings_2048.shape
        text_embeddings_2048 = text_embeddings_2048.repeat(1, num_images_per_prompt, 1)
        text_embeddings_2048 = text_embeddings_2048.view(bs_embed * num_images_per_prompt, seq_len, -1)
        add_text_embeds = add_text_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input_ids = self.tokenizer(uncond_tokens).to(device)
            _, uncond_embeddings = self.text_encoder.encode_text(uncond_input_ids)
            add_text_embeds_uncond, uncond_embeddings_2048 = self.proj(uncond_embeddings)

            # duplicate unconditional embeddings for each generation per prompt
            seq_len = uncond_embeddings_2048.shape[1]
            uncond_embeddings_2048 = uncond_embeddings_2048.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings_2048 = uncond_embeddings_2048.view(batch_size * num_images_per_prompt, seq_len, -1)
            add_text_embeds_uncond = add_text_embeds_uncond.repeat(1, num_images_per_prompt).view(batch_size * num_images_per_prompt, -1)

            # stack unconditional and conditional embeddings into one batch
            text_embeddings_2048 = torch.cat([uncond_embeddings_2048, text_embeddings_2048])
            add_text_embeds = torch.cat([add_text_embeds_uncond, add_text_embeds])

        return text_embeddings_2048, add_text_embeds

    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        height = height or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        width = width or self.pipe.unet.config.sample_size * self.pipe.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self.pipe._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds, add_text_embeds = self.encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt)

        self.pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.pipe.scheduler.timesteps

        num_channels_latents = self.pipe.unet.config.in_channels
        latents = self.pipe.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)

        add_time_ids = self._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype)
        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        for i, t in enumerate(self.pipe.progress_bar(timesteps)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # reset the VAE to float32 before selectively casting parts of it below
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                LoRAXFormersAttnProcessor,
                LoRAAttnProcessor2_0,
            ),
        )
        # if xformers or torch_2_0 is used attention block does not need
        # to be in float32 which can save lots of memory
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(latents.dtype)
            self.vae.decoder.conv_in.to(latents.dtype)
            self.vae.decoder.mid_block.to(latents.dtype)
        else:
            latents = latents.float()

        # un-normalize latents with the VAE's latents_mean / latents_std when its config provides them
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type="np")

        if output_type == "pil":
            image = self.pipe.numpy_to_pil(image)

        return image


if __name__ == '__main__':
    device = "cuda"
    dtype = torch.float16

    model_id = "playgroundai/playground-v2.5-1024px-aesthetic"
    # local path to the downloaded Mul-OpenCLIP checkpoint
    text_encoder_path = 'laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/open_clip_pytorch_model.bin'
    # local path to the downloaded PEA adapter weights for Playground v2.5
    proj_path = "OPPOer/PEA-Diffusion/pytorch_model_pg.bin"

    sdt = StableDiffusionTest(model_id, text_encoder_path, proj_path)

    batch = 2
    height = 1024
    width = 1024
    while True:
        raw_text = input("\nPlease Input Query (stop to exit) >>> ")
        if not raw_text:
            print('Query should not be empty!')
            continue
        if raw_text == "stop":
            break
        images = sdt([raw_text] * batch, height=height, width=width)
        grid = image_grid(images, rows=1, cols=batch)
        grid.save("PG.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.

## `stable-cascade`

Coming soon.

# License

The adapter itself is released under the Apache License 2.0, but its use must also comply with the license of the base model it is paired with.

# Citation

```
@misc{ma2023peadiffusion,
      title={PEA-Diffusion: Parameter-Efficient Adapter with Knowledge Distillation in non-English Text-to-Image Generation},
      author={Jian Ma and Chen Chen and Qingsong Xie and Haonan Lu},
      year={2023},
      eprint={2311.17086},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
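For readers who want to see the tensor shapes flowing through the SDXL adapter in isolation, here is a minimal sketch that re-declares the same `MLP` layer layout used in the SDXL script with random (untrained) weights. The input is a random stand-in for Mul-OpenCLIP token features; the batch size of 2 and the sequence length of 77 are illustrative assumptions, not values taken from the tokenizer, and no checkpoint, tokenizer, or pipeline is loaded.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Same layer layout as the SDXL adapter in the script above (random weights here)."""

    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.act_fn(self.fc1(x))
        x = self.fc2(x)                # (batch, seq, out_dim)
        x2 = self.fc3(self.act_fn(x))  # (batch, seq, out_dim1): per-token context
        if self.use_residual:
            x = x + residual
        x1 = torch.mean(x, 1)          # (batch, out_dim): mean-pooled embedding
        return x1, x2


torch.manual_seed(0)
adapter = MLP(1024, 1280, 1024, 2048, use_residual=False)

# Hypothetical stand-in for Mul-OpenCLIP token features: 2 prompts x 77 tokens x 1024 dims.
token_features = torch.randn(2, 77, 1024)
with torch.no_grad():
    pooled, context = adapter(token_features)

print(pooled.shape, context.shape)
```

In the SDXL script, the pooled output plays the role of `text_embeds` in `added_cond_kwargs`, and the per-token output is passed to the frozen UNet as `encoder_hidden_states`.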