import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor
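
# How generation works: the language model autoregressively emits one discrete
# VQ image token per 16x16 patch. For classifier-free guidance (CFG), each
# sample occupies two batch rows, an even row with the full prompt
# (conditional) and an odd row with the prompt padded out (unconditional);
# the two logit streams are blended at every decoding step.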
def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 1,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16,
             generator: torch.Generator = None):
    # Interleave conditional (even) and unconditional (odd) prompt rows; the
    # unconditional rows pad out everything except the first and last token.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = processor.pad_id
    inputs_embeds = model.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
    pkv = None
    for i in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=inputs_embeds,
                                             use_cache=True,
                                             past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = model.gen_head(hidden_states[:, -1, :])
        # Classifier-free guidance: push the conditional logits away from the
        # unconditional ones by the guidance weight.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1, generator=generator)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Feed the sampled token back into both the conditional and the
        # unconditional stream for the next step.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = model.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    # Decode the VQ codes back into pixels; the latent grid is laid out in
    # (H, W) order, i.e. (height // patch_size, width // patch_size).
    patches = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                 shape=[parallel_size, 8, height // patch_size, width // patch_size])

    return generated_tokens.to(dtype=torch.int), patches
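
# unpack() converts the decoder output from floats in [-1, 1], laid out
# (B, C, H, W), into uint8 (B, H, W, 3) arrays that PIL can consume.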
def unpack(dec, width, height, parallel_size=1):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    # Numpy image arrays are indexed (height, width, channels).
    visual_img = np.zeros((parallel_size, height, width, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img
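
# generate_image() wires the UI to the sampler: it wraps the prompt in the
# chat template the model was fine-tuned with, appends the image-start tag so
# the model switches to emitting image tokens, then decodes the sampled tokens
# into a PIL image. Requested sizes are snapped down to multiples of the
# 16-pixel patch size so the latent grid consists of whole patches.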
@torch.inference_mode()
def generate_image(prompt,
                   width,
                   height,
                   guidance,
                   seed):
    if seed > -1:
        generator = torch.Generator('cpu').manual_seed(seed)
    else:
        generator = None
    messages = [{'role': 'User', 'content': prompt},
                {'role': 'Assistant', 'content': ''}]
    text = processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                               sft_format=processor.sft_format,
                                                               system_prompt='')
    text = text + processor.image_start_tag
    # Encode the fully formatted conversation, not the bare prompt.
    input_ids = torch.LongTensor(processor.tokenizer.encode(text))
    width = width // 16 * 16
    height = height // 16 * 16
    output, patches = generate(input_ids,
                               width,
                               height,
                               cfg_weight=guidance,
                               image_token_num_per_image=(width // 16) * (height // 16),
                               generator=generator)
    images = unpack(patches, width, height)

    return Image.fromarray(images[0]), seed, ''

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value='portrait, color, cinematic')
            width = gr.Slider(256, 1536, 896, step=16, label='Width')
            height = gr.Slider(256, 1536, 1152, step=16, label='Height')
            guidance = gr.Slider(1.0, 10.0, 5, step=0.1, label='Guidance')
            seed = gr.Number(-1, precision=0, label='Seed (-1 for random)')

            generate_btn = gr.Button('Generate')

        with gr.Column():
            output_image = gr.Image(label='Generated Image')
            seed_output = gr.Textbox(label='Used Seed')
            intermediate_output = gr.Gallery(label='Output', elem_id='gallery', visible=False)

    # The input order must match the generate_image signature:
    # (prompt, width, height, guidance, seed).
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )

if __name__ == '__main__':
    model_path = 'deepseek-ai/Janus-1.3B'
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    model = model.to(torch.bfloat16)
    demo.launch()
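
# Usage note: this script assumes the `janus` package from
# https://github.com/deepseek-ai/Janus is installed alongside gradio, torch,
# and transformers. The model is kept on CPU in bfloat16 here, so generation
# is slow; moving the model and the token tensors to a GPU (and matching the
# generator device) will make it far faster.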