import gradio as gr
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
import numpy as np
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
import torch
##
# Code from deepseek-ai/Janus
# Space from huggingface/twodgirl.
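#
# Gradio text-to-image demo for deepseek-ai/Janus-1.3B: the prompt is wrapped
# in the Janus chat template, image tokens are sampled autoregressively with
# classifier-free guidance in generate(), and the VQ decoder turns them back
# into pixels. Assumes the `janus` package (from the deepseek-ai/Janus repo)
# is installed alongside gradio, torch, transformers, numpy, and Pillow.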
def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 1,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Build interleaved conditional/unconditional sequences for classifier-free
    # guidance: even rows keep the prompt, odd rows are padded out.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = processor.pad_id
    inputs_embeds = model.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    # Autoregressively sample image tokens, reusing the KV cache between steps.
    pkv = None
    for i in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=inputs_embeds,
                                             use_cache=True,
                                             past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = model.gen_head(hidden_states[:, -1, :])
        # Classifier-free guidance: mix conditional and unconditional logits.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Duplicate the sampled token for the cond/uncond rows and feed it back in.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = model.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the discrete token grid back into image patches with the VQ decoder.
    patches = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                 shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches
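
# Worked example of the size/token bookkeeping above: with the defaults
# image_token_num_per_image=576 and patch_size=16, the decoder expects a
# 24x24 latent grid, matching the 384x384 default in the UI
# (384 // 16 = 24 and 24 * 24 = 576). Other width/height values change the
# decode_code shape but not the number of sampled tokens, so sizes other than
# 384x384 may not line up with the fixed 576-token budget.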

def unpack(dec, width, height, parallel_size=1):
    # Convert decoder output from [-1, 1] float tensors to uint8 HWC images.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img

@torch.inference_mode()
def generate_image(prompt,
                   width,
                   height,
                   # num_steps,
                   guidance,
                   seed):
    if seed > -1:
        # Seed the global RNG so torch.multinomial in generate() is reproducible.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
    # Wrap the prompt in the Janus chat template and append the image start tag.
    messages = [{'role': 'User', 'content': prompt},
                {'role': 'Assistant', 'content': ''}]
    text = processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                               sft_format=processor.sft_format,
                                                               system_prompt='')
    text = text + processor.image_start_tag
    input_ids = torch.LongTensor(processor.tokenizer.encode(text))
    # Snap the requested size down to a multiple of 16 (the VQ patch size).
    output, patches = generate(input_ids,
                               width // 16 * 16,
                               height // 16 * 16,
                               cfg_weight=guidance)
    images = unpack(patches,
                    width // 16 * 16,
                    height // 16 * 16)

    return Image.fromarray(images[0]), seed, ''

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value='portrait, color, cinematic')
            width = gr.Slider(64, 1536, 384, step=16, label='Width')
            height = gr.Slider(64, 1536, 384, step=16, label='Height')
            guidance = gr.Slider(1.0, 10.0, 5, step=0.1, label='Guidance')
            seed = gr.Number(-1, precision=0, label='Seed (-1 for random)')
            generate_btn = gr.Button('Generate')
        with gr.Column():
            output_image = gr.Image(label='Generated Image')
            seed_output = gr.Textbox(label='Used Seed')
            intermediate_output = gr.Gallery(label='Output', elem_id='gallery', visible=False)

    prompt.submit(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )

if __name__ == '__main__':
    cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'deepseek-ai/Janus-1.3B'
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    # model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    # Force eager attention in the language-model config before loading the weights.
    config = AutoConfig.from_pretrained(model_path)
    language_config = config.language_config
    language_config._attn_implementation = 'eager'
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 language_config=language_config,
                                                 trust_remote_code=True)
    # bfloat16 on GPU; fall back to float16 when no GPU is available.
    if torch.cuda.is_available():
        model = model.to(torch.bfloat16).cuda()
    else:
        model = model.to(torch.float16)

    demo.launch()