import os
import PIL.Image
import torch
import numpy as np
from janus.utils.io import load_pil_images
from model_loader import load_model_and_processor
from janus.models import MultiModalityCausalLM, VLChatProcessor
from functools import lru_cache
import spaces

def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
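    """Duplicate the prompt embeddings into interleaved conditional/unconditional rows.

    Rows 0, 2, 4, ... carry the real prompt; rows 1, 3, 5, ... carry a padded
    "empty" prompt, which is what classifier-free guidance contrasts against.
    """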
    # Build an "unconditional" prompt of the same length: all pad tokens except the
    # boundary positions, mirroring the reference Janus CFG setup.
    uncond_input_ids = torch.full((1, input_embeds.shape[1]),
                                  vl_chat_processor.pad_id,
                                  dtype=torch.long,
                                  device=input_embeds.device)
    uncond_input_ids[:, 0] = vl_chat_processor.tokenizer.bos_token_id
    uncond_input_ids[:, -1] = vl_chat_processor.tokenizer.eos_token_id

    # Embed the unconditional ids and copy the conditional sequence's final embedding
    # (the image start tag) so both branches begin image generation identically.
    uncond_input_embeds = mmgpt.language_model.get_input_embeddings()(uncond_input_ids)
    uncond_input_embeds[:, -1, :] = input_embeds[:, -1, :]

    # Repeat both branches batch_size times, then interleave them so that even rows
    # are conditional and odd rows are unconditional.
    cond_input_embeds = input_embeds.repeat(batch_size, 1, 1)
    uncond_input_embeds = uncond_input_embeds.repeat(batch_size, 1, 1)

    combined_input_embeds = torch.stack([cond_input_embeds, uncond_input_embeds], dim=1)
    combined_input_embeds = combined_input_embeds.view(batch_size * 2, -1, input_embeds.shape[-1])

    return combined_input_embeds

@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    inputs_embeds,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
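    """Autoregressively sample `image_token_num_per_image` image tokens with
    classifier-free guidance, then decode them with the vision tokenizer.

    Returns a list of `parallel_size` PIL images of size `img_size` x `img_size`.
    """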
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

    # Interleave conditional/unconditional embeddings for classifier-free guidance.
    inputs_embeds = prepare_classifier_free_guidance_input(inputs_embeds, vl_chat_processor, mmgpt, parallel_size)

    for i in range(image_token_num_per_image):
        # Reuse the KV cache after the first step; only the newly sampled token is fed back in.
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None,
        )
        hidden_states = outputs.last_hidden_state

        logits = mmgpt.gen_head(hidden_states[:, -1, :])

        # Even rows are conditional, odd rows unconditional (see prepare_classifier_free_guidance_input).
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        # Duplicate each sampled token for the conditional and unconditional streams.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the discrete image tokens back to pixels and rescale from [-1, 1] to [0, 255].
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return [PIL.Image.fromarray(visual_img[i]) for i in range(parallel_size)]

@lru_cache(maxsize=1)
def get_start_tag_embed(vl_gpt, vl_chat_processor):
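    """Return the embedding of the image start tag (cached, since it never changes)."""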
    with torch.no_grad():
        return vl_gpt.language_model.get_input_embeddings()(
            vl_chat_processor.tokenizer.encode(vl_chat_processor.image_start_tag, add_special_tokens=False, return_tensors="pt").to(vl_gpt.device)
        )
    
@spaces.GPU
def process_and_generate(input_image, prompt, num_images=4, cfg_weight=5):
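    """Generate `num_images` variations of `input_image`, optionally steered by `prompt`."""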
    # Set the model path
    model_path = "deepseek-ai/Janus-1.3B"

    # Load the model and processor
    vl_gpt, vl_chat_processor = load_model_and_processor(model_path)
    
    start_tag_embed = get_start_tag_embed(vl_gpt, vl_chat_processor)

    # f-strings can't contain backslash escapes (before Python 3.12), hence this helper.
    nl = '\n'
    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>{nl + prompt if prompt else ''}",
            "images": [input_image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = load_pil_images(conversation)
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(vl_gpt.device)

    with torch.no_grad():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # Append the image start tag embedding so generation begins emitting image tokens.
    inputs_embeds = torch.cat((inputs_embeds, start_tag_embed), dim=1)

    generated_images = generate(
        vl_gpt,
        vl_chat_processor,
        inputs_embeds,
        parallel_size=num_images,
        cfg_weight=cfg_weight
    )

    return generated_images
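
# --- Usage sketch (not part of the original app) -------------------------------------
# A minimal local smoke test, assuming this file sits next to model_loader.py, a CUDA
# GPU is available, and a reference image "example.jpg" exists (the path is
# hypothetical). Outside a Hugging Face Space, the @spaces.GPU decorator should act
# as a simple passthrough.
if __name__ == "__main__":
    images = process_and_generate(
        "example.jpg",  # load_pil_images opens image paths via PIL
        prompt="the same scene as a watercolor painting",
        num_images=2,
        cfg_weight=5,
    )
    for idx, img in enumerate(images):
        img.save(f"variation_{idx}.png")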