Spaces:
Running
on
Zero
Running
on
Zero
thomasgauthier
commited on
Commit
•
bdf9962
1
Parent(s):
fa851d1
first commit
Browse files- app.py +14 -5
- gradio_interface.py +32 -0
- image_generator.py +117 -0
- model_loader.py +14 -0
- requirements.txt +7 -0
app.py
CHANGED
@@ -1,7 +1,16 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import spaces
|
3 |
+
from model_loader import load_model_and_processor
|
4 |
+
from image_generator import process_and_generate
|
5 |
+
from gradio_interface import create_gradio_interface
|
6 |
|
7 |
+
if __name__ == "__main__":
|
8 |
+
# Set the model path
|
9 |
+
model_path = "deepseek-ai/Janus-1.3B"
|
10 |
|
11 |
+
# Load the model and processor
|
12 |
+
vl_gpt, vl_chat_processor = load_model_and_processor(model_path)
|
13 |
+
|
14 |
+
# Create and launch the Gradio interface
|
15 |
+
demo = create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate)
|
16 |
+
demo.launch(allowed_paths=["/"])
|
gradio_interface.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
def create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate):
|
5 |
+
def gradio_process_and_generate(input_image, prompt, num_images, cfg_weight):
|
6 |
+
return process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images, cfg_weight)
|
7 |
+
|
8 |
+
explanation = """Janus 1.3B uses a differerent visual encoder for understanding and generation.
|
9 |
+
|
10 |
+
![Janus Model Architecture](file/images/janus_architecture.svg)
|
11 |
+
|
12 |
+
Here, by feeding the model an image and then asking it to generate that same image, we visualize the model's ability to translate input (understanding) embedding space to generative embedding space."""
|
13 |
+
|
14 |
+
with gr.Blocks() as demo:
|
15 |
+
gr.Markdown("# How Janus-1.3B sees itself")
|
16 |
+
|
17 |
+
with gr.Row():
|
18 |
+
input_image = gr.Image(type="filepath", label="Input Image")
|
19 |
+
output_images = gr.Gallery(label="Generated Images", columns=2, rows=2)
|
20 |
+
prompt = gr.Textbox(label="Prompt", value="Exactly what is shown in the image.")
|
21 |
+
num_images = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="Number of Images to Generate")
|
22 |
+
cfg_weight = gr.Slider(minimum=1, maximum=10, value=5, step=0.1, label="CFG Weight")
|
23 |
+
generate_btn = gr.Button("Generate", variant="primary", size="lg")
|
24 |
+
|
25 |
+
generate_btn.click(
|
26 |
+
fn=gradio_process_and_generate,
|
27 |
+
inputs=[input_image, prompt, num_images, cfg_weight],
|
28 |
+
outputs=output_images
|
29 |
+
)
|
30 |
+
gr.Markdown(explanation)
|
31 |
+
|
32 |
+
return demo
|
image_generator.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import PIL.Image
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from janus.utils.io import load_pil_images
|
6 |
+
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
7 |
+
from functools import lru_cache
|
8 |
+
|
9 |
+
|
10 |
+
def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
|
11 |
+
uncond_input_ids = torch.full((1, input_embeds.shape[1]),
|
12 |
+
vl_chat_processor.pad_id,
|
13 |
+
dtype=torch.long,
|
14 |
+
device=input_embeds.device)
|
15 |
+
uncond_input_ids[:, 0] = input_embeds.shape[1] - 1
|
16 |
+
uncond_input_ids[:, -1] = vl_chat_processor.tokenizer.eos_token_id
|
17 |
+
|
18 |
+
uncond_input_embeds = mmgpt.language_model.get_input_embeddings()(uncond_input_ids)
|
19 |
+
uncond_input_embeds[:, -1, :] = input_embeds[:, -1, :]
|
20 |
+
|
21 |
+
cond_input_embeds = input_embeds.repeat(batch_size, 1, 1)
|
22 |
+
uncond_input_embeds = uncond_input_embeds.repeat(batch_size, 1, 1)
|
23 |
+
|
24 |
+
combined_input_embeds = torch.stack([cond_input_embeds, uncond_input_embeds], dim=1)
|
25 |
+
combined_input_embeds = combined_input_embeds.view(batch_size * 2, -1, input_embeds.shape[-1])
|
26 |
+
|
27 |
+
return combined_input_embeds
|
28 |
+
|
29 |
+
@spaces.GPU
|
30 |
+
@torch.inference_mode()
|
31 |
+
def generate(
|
32 |
+
mmgpt: MultiModalityCausalLM,
|
33 |
+
vl_chat_processor: VLChatProcessor,
|
34 |
+
inputs_embeds,
|
35 |
+
temperature: float = 1,
|
36 |
+
parallel_size: int = 1,
|
37 |
+
cfg_weight: float = 5,
|
38 |
+
image_token_num_per_image: int = 576,
|
39 |
+
img_size: int = 384,
|
40 |
+
patch_size: int = 16,
|
41 |
+
):
|
42 |
+
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
|
43 |
+
|
44 |
+
inputs_embeds = prepare_classifier_free_guidance_input(inputs_embeds, vl_chat_processor, mmgpt, parallel_size)
|
45 |
+
|
46 |
+
for i in range(image_token_num_per_image):
|
47 |
+
outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
|
48 |
+
hidden_states = outputs.last_hidden_state
|
49 |
+
|
50 |
+
logits = mmgpt.gen_head(hidden_states[:, -1, :])
|
51 |
+
|
52 |
+
logit_cond = logits[0::2, :]
|
53 |
+
logit_uncond = logits[1::2, :]
|
54 |
+
|
55 |
+
logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
|
56 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
57 |
+
|
58 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
59 |
+
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
60 |
+
|
61 |
+
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
62 |
+
img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
|
63 |
+
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
64 |
+
|
65 |
+
dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
|
66 |
+
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
67 |
+
|
68 |
+
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
69 |
+
|
70 |
+
visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
|
71 |
+
visual_img[:, :, :] = dec
|
72 |
+
|
73 |
+
generated_images = []
|
74 |
+
for i in range(parallel_size):
|
75 |
+
generated_images.append(PIL.Image.fromarray(visual_img[i]))
|
76 |
+
|
77 |
+
return generated_images
|
78 |
+
|
79 |
+
@lru_cache(maxsize=1)
|
80 |
+
def get_start_tag_embed(vl_gpt, vl_chat_processor):
|
81 |
+
with torch.no_grad():
|
82 |
+
return vl_gpt.language_model.get_input_embeddings()(
|
83 |
+
vl_chat_processor.tokenizer.encode(vl_chat_processor.image_start_tag, add_special_tokens=False, return_tensors="pt").to(vl_gpt.device)
|
84 |
+
)
|
85 |
+
|
86 |
+
def process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images=4, cfg_weight=5):
|
87 |
+
start_tag_embed = get_start_tag_embed(vl_gpt, vl_chat_processor)
|
88 |
+
|
89 |
+
nl = '\n'
|
90 |
+
conversation = [
|
91 |
+
{
|
92 |
+
"role": "User",
|
93 |
+
"content": f"<image_placeholder>{nl + prompt if prompt else ''}",
|
94 |
+
"images": [input_image],
|
95 |
+
},
|
96 |
+
{"role": "Assistant", "content": ""},
|
97 |
+
]
|
98 |
+
|
99 |
+
pil_images = load_pil_images(conversation)
|
100 |
+
prepare_inputs = vl_chat_processor(
|
101 |
+
conversations=conversation, images=pil_images, force_batchify=True
|
102 |
+
).to(vl_gpt.device)
|
103 |
+
|
104 |
+
with torch.no_grad():
|
105 |
+
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
106 |
+
|
107 |
+
inputs_embeds = torch.cat((inputs_embeds, start_tag_embed), dim=1)
|
108 |
+
|
109 |
+
generated_images = generate(
|
110 |
+
vl_gpt,
|
111 |
+
vl_chat_processor,
|
112 |
+
inputs_embeds,
|
113 |
+
parallel_size=num_images,
|
114 |
+
cfg_weight=cfg_weight
|
115 |
+
)
|
116 |
+
|
117 |
+
return generated_images
|
model_loader.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM
|
3 |
+
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
4 |
+
|
5 |
+
def load_model_and_processor(model_path):
|
6 |
+
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
7 |
+
tokenizer = vl_chat_processor.tokenizer
|
8 |
+
|
9 |
+
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
|
10 |
+
model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
|
11 |
+
)
|
12 |
+
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
|
13 |
+
|
14 |
+
return vl_gpt, vl_chat_processor
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
numpy
|
3 |
+
Pillow
|
4 |
+
gradio
|
5 |
+
janus @ git+https://github.com/deepseek-ai/Janus
|
6 |
+
transformers
|
7 |
+
spaces
|