import gradio as gr
import spaces # Import spaces for ZeroGPU compatibility
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
import numpy as np
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
import torch
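
# generate() autoregressively samples Janus's discrete image tokens with
# classifier-free guidance (CFG): each prompt is run in parallel with and
# without text conditioning, and the two logit streams are blended with
# cfg_weight at every sampling step.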
def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 1,
             cfg_weight: float = 5,
             image_token_num_per_image=None,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()
    if image_token_num_per_image is None:
        # One token per 16x16 patch (576 for the default 384x384), so the
        # token count always matches the latent grid passed to decode_code.
        image_token_num_per_image = (width // patch_size) * (height // patch_size)
    # Duplicate each prompt: even rows keep the full prompt (conditional
    # branch), odd rows are masked with the pad token (unconditional branch).
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = processor.pad_id
    inputs_embeds = model.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
    pkv = None
    for i in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=inputs_embeds,
                                             use_cache=True,
                                             past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = model.gen_head(hidden_states[:, -1, :])
        # Classifier-free guidance: push the conditional logits away from the
        # unconditional ones by cfg_weight before sampling.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Feed the sampled token back into both CFG branches.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = model.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    # Decode the sampled image tokens back to pixels with the VQ decoder.
    patches = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                 shape=[parallel_size, 8, width // patch_size, height // patch_size])
    return generated_tokens.to(dtype=torch.int), patches
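
# unpack() converts the decoder output from NCHW floats in [-1, 1] to
# NHWC uint8 arrays that PIL can consume.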
def unpack(dec, width, height, parallel_size=1):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
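
# generate_image() is the Gradio handler: it seeds the RNGs when a fixed seed
# is requested, wraps the prompt in the processor's SFT chat template, appends
# the image-start tag, and delegates sampling to generate().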
@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt,
                   width,
                   height,
                   guidance,
                   seed):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    with torch.no_grad():
        if seed > -1:
            # Seed the global RNGs so the multinomial sampling in generate()
            # is reproducible.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
        messages = [{'role': 'User', 'content': prompt},
                    {'role': 'Assistant', 'content': ''}]
        text = processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                   sft_format=processor.sft_format,
                                                                   system_prompt='')
        # Append the image-start tag so the model begins emitting image tokens.
        text = text + processor.image_start_tag
        input_ids = torch.LongTensor(processor.tokenizer.encode(text))
        # Round the requested size down to a multiple of the 16-pixel patch size.
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16)
        return Image.fromarray(images[0]), seed, ''
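
# A minimal sketch of calling the handler directly (hypothetical, for local
# testing without the UI; assumes `model`, `processor`, and `cuda_device` are
# already set up as in the __main__ block below):
#
#   image, used_seed, _ = generate_image('portrait, color, cinematic',
#                                        384, 384, guidance=5.0, seed=42)
#   image.save('sample.png')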
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value='portrait, color, cinematic')
            width = gr.Slider(64, 1536, 384, step=16, label='Width')
            height = gr.Slider(64, 1536, 384, step=16, label='Height')
            guidance = gr.Slider(1.0, 10.0, 5, step=0.1, label='Guidance')
            seed = gr.Number(-1, precision=0, label='Seed (-1 for random)')
            generate_btn = gr.Button('Generate')
        with gr.Column():
            output_image = gr.Image(label='Generated Image')
            seed_output = gr.Textbox(label='Used Seed')
            intermediate_output = gr.Gallery(label='Output', elem_id='gallery', visible=False)

    prompt.submit(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )
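
# The processor, model, and device string below are created once at start-up
# and shared as module-level globals by generate() and generate_image().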
if __name__ == '__main__':
    cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'deepseek-ai/Janus-1.3B'
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    config = AutoConfig.from_pretrained(model_path)
    language_config = config.language_config
    # Use eager attention so the model runs without a FlashAttention install.
    language_config._attn_implementation = 'eager'
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 language_config=language_config,
                                                 trust_remote_code=True)
    if torch.cuda.is_available():
        model = model.to(torch.bfloat16).cuda()
    else:
        model = model.to(torch.float16)
    demo.launch()