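# Gradio demo: text-to-image generation with deepseek-ai/Janus-1.3B, set up to
# run as a Hugging Face ZeroGPU Space.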
import gradio as gr
import spaces  # Import spaces for ZeroGPU compatibility
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
import numpy as np
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
import torch

def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 1,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
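    """Sample image tokens autoregressively with classifier-free guidance.

    Conditional and unconditional prompts are interleaved in a single batch
    (two rows per image); at each step the logits are combined as
    uncond + cfg_weight * (cond - uncond) before sampling. The finished token
    grid is decoded into image patches by the generation vision model.
    """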
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()
    
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
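    # Even rows hold the full conditional prompt; odd rows are padded out to act
    # as the unconditional branch for classifier-free guidance.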
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = processor.pad_id
    inputs_embeds = model.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
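    # Generate one image token per step, reusing the KV cache between steps.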
    for i in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=inputs_embeds,
                                             use_cache=True,
                                             past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = model.gen_head(hidden_states[:, -1, :])
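        # Classifier-free guidance: even rows are conditional, odd rows unconditional.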
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
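        # Duplicate the sampled token so both CFG rows receive it as the next input.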
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = model.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    patches = model.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                 shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches

def unpack(dec, width, height, parallel_size=1):
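    """Convert decoded patches in [-1, 1] to a uint8 RGB array per image."""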
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img

@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt,
                   width,
                   height,
                   guidance,
                   seed):
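    """Gradio callback: build the Janus SFT prompt for the text prompt, append
    the image-start tag, and generate one image at a 16-aligned resolution."""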
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    
    with torch.no_grad():
        # Seed the RNGs so sampling is reproducible; -1 leaves the seed random.
        if seed > -1:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        messages = [{'role': 'User', 'content': prompt},
                    {'role': 'Assistant', 'content': ''}]
        text = processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                   sft_format=processor.sft_format,
                                                                   system_prompt='')
        text = text + processor.image_start_tag
        input_ids = torch.LongTensor(processor.tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16)

        return Image.fromarray(images[0]), seed, ''
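
# Gradio UI: prompt and sampling controls on the left, the generated image and
# the seed used on the right.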

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value='portrait, color, cinematic')
            width = gr.Slider(64, 1536, 384, step=16, label='Width')
            height = gr.Slider(64, 1536, 384, step=16, label='Height')
            guidance = gr.Slider(1.0, 10.0, 5, step=0.1, label='Guidance')
            seed = gr.Number(-1, precision=0, label='Seed (-1 for random)')

            generate_btn = gr.Button('Generate')

        with gr.Column():
            output_image = gr.Image(label='Generated Image')
            seed_output = gr.Textbox(label='Used Seed')
            intermediate_output = gr.Gallery(label='Output', elem_id='gallery', visible=False)

        prompt.submit(
            fn=generate_image,
            inputs=[prompt, width, height, guidance, seed],
            outputs=[output_image, seed_output, intermediate_output],
        )
        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, width, height, guidance, seed],
            outputs=[output_image, seed_output, intermediate_output],
        )

if __name__ == '__main__':
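    # Load the processor and model once at startup; use bfloat16 on GPU, half
    # precision otherwise.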
    cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'deepseek-ai/Janus-1.3B'
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    config = AutoConfig.from_pretrained(model_path)
    language_config = config.language_config
    language_config._attn_implementation = 'eager'
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 language_config=language_config,
                                                 trust_remote_code=True)
    if torch.cuda.is_available():
        model = model.to(torch.bfloat16).cuda()
    else:
        model = model.to(torch.float16)
    demo.launch()