Flux Qwen Neutered
Qwen is a lightweight alternative to the T5 text encoder, for use with the FLUX.1-dev model.
For numerical stability, it requires both the Qwen and the Flux tokenizers, which adds about 10 MB of data.
This repo is experimental work, not a final replacement for the built-in text encoder.
Compared to a standalone version, this demo has better accuracy and a shorter training time, mainly because it reuses a pre-trained model.
Inference
from diffusers import FluxPipeline, FluxTransformer2DModel
from text_encoder import PretrainedTextEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Optional, Union
def setup_qwen(pipe,
               qwen_path='Qwen/Qwen2.5-0.5B',
               device=None,
               dtype=torch.bfloat16):
    """Attach a Qwen tokenizer and Qwen backbone to *pipe* in place.

    Two attributes are set on the pipeline:

    * ``pipe.qwen_tokenizer`` -- ``AutoTokenizer`` loaded from ``qwen_path``;
    * ``pipe.qwen_model`` -- the inner ``.model`` submodule of the
      ``AutoModelForCausalLM`` loaded from ``qwen_path`` (only this submodule
      is stored, not the full causal-LM wrapper).

    Args:
        pipe: pipeline object to mutate.
        qwen_path: Hugging Face repo id or local path of the Qwen checkpoint.
        device: forwarded to ``from_pretrained`` as ``device_map``.
        dtype: forwarded to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The same *pipe* object, so the call can be chained.
    """
    tokenizer = AutoTokenizer.from_pretrained(qwen_path)
    pipe.qwen_tokenizer = tokenizer

    causal_lm = AutoModelForCausalLM.from_pretrained(
        qwen_path,
        device_map=device,
        torch_dtype=dtype,
    )
    # Keep only the transformer backbone; its last_hidden_state is what the
    # pipeline consumes.
    pipe.qwen_model = causal_lm.model
    return pipe
class FluxQwenPipeline(FluxPipeline):
    """FluxPipeline variant that produces the T5-slot prompt embeddings via Qwen.

    The stock ``text_encoder_2`` (T5) may be ``None``. Instead, the pipeline
    expects these extra attributes to be attached before it is called:

    * ``qwen_tokenizer`` and ``qwen_model`` (e.g. via ``setup_qwen``);
    * an adapter encoder, read from ``self.qwen_encoder`` if that attribute is
      set, otherwise from a module-level ``encoder`` global (the way the
      ``__main__`` demo provides it).
    """

    def _resolve_text_adapter(self):
        """Return the adapter that maps Qwen hidden states to T5-style embeddings.

        Prefers ``self.qwen_encoder``. For backward compatibility it falls back
        to a module-level ``encoder`` global.

        Raises:
            RuntimeError: if no adapter is available.
        """
        adapter = getattr(self, 'qwen_encoder', None)
        if adapter is None:
            # Original behaviour: the demo script defines `encoder` globally.
            adapter = globals().get('encoder')
        if adapter is None:
            raise RuntimeError(
                'No Qwen text adapter found: set `pipe.qwen_encoder` to a loaded '
                'PretrainedTextEncoder before running the pipeline.')
        return adapter

    def _get_t5_prompt_embeds(self,
                              prompt: Union[str, List[str]] = None,
                              num_images_per_prompt: int = 1,
                              max_sequence_length: int = 512,
                              device: Optional[torch.device] = None,
                              dtype: Optional[torch.dtype] = None):
        """Encode *prompt* into T5-slot embeddings using Qwen plus the adapter.

        Overrides ``FluxPipeline._get_t5_prompt_embeds`` so the pipeline can
        run without a T5 text encoder.

        Args:
            prompt: a single prompt or a list of prompts.
            num_images_per_prompt: images generated per prompt. Each prompt's
                embedding is repeated this many times along dim 0, prompt-major,
                to match the batch layout FluxPipeline uses elsewhere.
            max_sequence_length: pad/truncate length for both tokenizers.
            device: target device. Defaults to ``self._execution_device``.
            dtype: if given, the embeddings are cast to this dtype.

        Returns:
            Adapter output with ``len(prompt) * num_images_per_prompt`` rows.
            Positions that are padding in the Flux ``tokenizer_2`` encoding are
            zeroed out.
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        adapter = self._resolve_text_adapter()

        qwen_out = self.encode_qwen(prompt, max_sequence_length, device)
        inputs = self.tokenizer_2(prompt,
                                  return_tensors='pt',
                                  padding='max_length',
                                  truncation=True,
                                  max_length=max_sequence_length)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        # The adapter combines Qwen hidden states with the adapter's own
        # `shared` embedding of the Flux tokenizer_2 ids.
        output = adapter(qwen_out, adapter.shared(input_ids),
                         max_length=max_sequence_length)
        # Zero out padded positions (mask is 1 for real tokens, 0 for padding).
        output = output * attention_mask.unsqueeze(-1)
        if dtype is not None:
            output = output.to(dtype=dtype)
        if num_images_per_prompt > 1:
            # Same layout as FluxPipeline's repeat(1, n, 1).view(b * n, ...):
            # each prompt's rows are contiguous.
            output = output.repeat_interleave(num_images_per_prompt, dim=0)
        return output

    def encode_qwen(self, prompt, max_sequence_length=256, device=None):
        """Run *prompt* through the attached Qwen backbone.

        Args:
            prompt: a single prompt or a list of prompts.
            max_sequence_length: pad/truncate length for ``qwen_tokenizer``.
            device: device the token ids and attention mask are moved to.

        Returns:
            ``last_hidden_state`` of ``self.qwen_model``.
        """
        inputs = self.qwen_tokenizer(prompt,
                                     return_tensors='pt',
                                     padding='max_length',
                                     truncation=True,
                                     max_length=max_sequence_length)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        output = self.qwen_model(input_ids=input_ids,
                                 attention_mask=attention_mask)
        return output.last_hidden_state
if __name__ == '__main__':
    # Demo: generate a single image using the Qwen-based text path.

    # `encoder` must stay a module-level name: FluxQwenPipeline looks it up
    # as a global when building prompt embeddings.
    encoder = PretrainedTextEncoder.from_pretrained(
        'twodgirl/flux-qwen-neutered',
        device_map='cuda',
        torch_dtype=torch.bfloat16,
    )

    # T5 (text_encoder_2) is not loaded; the pipeline's overridden
    # _get_t5_prompt_embeds supplies those embeddings instead.
    pipe = FluxQwenPipeline.from_pretrained(
        'black-forest-labs/FLUX.1-dev',
        text_encoder_2=None,
        torch_dtype=torch.bfloat16,
    )
    setup_qwen(pipe, device='cuda')
    pipe.enable_model_cpu_offload()

    result = pipe('a black cat wearing a Pikachu cosplay')
    image = result.images[0]
    image.save('cat.png')
Disclaimer
Use of this code and the model requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.
- Downloads last month: 14
Model tree for twodgirl/flux-qwen-neutered
Base model: black-forest-labs/FLUX.1-dev