gabrielchua committed
Commit
4e757f3
1 Parent(s): 9f78eb1

Create app.py

Files changed (1): app.py +197 -0
app.py ADDED
import torch
import gradio as gr
import numpy as np
import spaces
from PIL import Image
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
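
# Note: this file assumes the `janus` package from DeepSeek's Janus repository
# (github.com/deepseek-ai/Janus) is installed, alongside torch, transformers,
# gradio, numpy, and spaces.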

# Specify the path to the model
model_path = "deepseek-ai/Janus-1.3B"

# Load the VLChatProcessor and tokenizer
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# Load the MultiModalityCausalLM model
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

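# On a Hugging Face ZeroGPU Space, @spaces.GPU attaches a GPU to each call for
# up to `duration` seconds; elsewhere the decorator is effectively a no-op.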
@spaces.GPU(duration=120)
def image_to_latex(image: Image.Image) -> str:
    """
    Convert an uploaded image of a formula into LaTeX code.
    """
    # Define the conversation with the uploaded image
    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\nConvert the formula into latex code.",
            "images": [image],
        },
        {"role": "Assistant", "content": ""},
    ]

    # The upload is already a PIL image, so use it directly
    # (janus.utils.io.load_pil_images expects file paths, not in-memory images)
    pil_images = [image.convert("RGB")]

    # Prepare the inputs for the model
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(vl_gpt.device)

    # Prepare input embeddings
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # Generate the response from the model
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False,
        use_cache=True,
    )
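    # With do_sample=False, decoding is greedy and deterministic, a sensible
    # default for transcription tasks where sampling mostly adds noise.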

    # Decode the generated tokens to get the answer
    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)

    return answer

@spaces.GPU(duration=120)
def text_to_image(prompt: str) -> Image.Image:
    """
    Generate an image based on the input text prompt.
    """
    # Define the conversation with the user prompt
    conversation = [
        {
            "role": "User",
            "content": prompt,
        },
        {"role": "Assistant", "content": ""},
    ]

    # Apply the SFT template to format the prompt
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt="",
    )
    prompt_text = sft_format + vl_chat_processor.image_start_tag

    # Encode the prompt
    input_ids = vl_chat_processor.tokenizer.encode(prompt_text)
    input_ids = torch.LongTensor(input_ids)

    # Prepare a two-row batch for classifier-free guidance: row 0 holds the full
    # conditional prompt; row 1 is the unconditional prompt, which keeps only the
    # BOS and image-start tokens and pads everything in between
    tokens = torch.zeros((2, len(input_ids)), dtype=torch.int).cuda()
    tokens[0, :] = input_ids
    tokens[1, :] = input_ids
    tokens[1, 1:-1] = vl_chat_processor.pad_id

    # Get input embeddings
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    # Generation parameters
    image_token_num_per_image = 576
    img_size = 384
    patch_size = 16
    cfg_weight = 5
    temperature = 1
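    # 576 tokens per image, since a 384x384 image becomes a
    # (384 // 16) x (384 // 16) = 24 x 24 grid of image tokens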

    # Initialize tensor to store generated tokens
    generated_tokens = torch.zeros((1, image_token_num_per_image), dtype=torch.int).cuda()

    # Generate image tokens autoregressively, reusing the KV cache between steps
    for i in range(image_token_num_per_image):
        if i == 0:
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True)
        else:
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=outputs.past_key_values,
            )

        hidden_states = outputs.last_hidden_state

        # Get logits and apply classifier-free guidance
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
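        # i.e. the standard CFG combination l_uncond + w * (l_cond - l_uncond);
        # a larger cfg_weight pushes samples closer to the prompt at some cost in diversity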

        # Sample the next token
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        # Feed the sampled token back in, duplicated across the conditional
        # and unconditional rows
        next_token_combined = torch.cat(
            [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
        ).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_combined)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the generated tokens to get the image
    dec = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[1, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
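    # (the vision decoder outputs values in roughly [-1, 1]; the line above
    # rescales them to [0, 255] for an 8-bit image)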

    # Convert to PIL Image
    visual_img = dec[0]
    image = Image.fromarray(visual_img)

    return image

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Janus-1.3B Gradio Demo
        This demo showcases two capabilities of the Janus-1.3B model:
        1. **Image to LaTeX**: Upload an image of a mathematical formula to convert it into LaTeX code.
        2. **Text to Image**: Enter a descriptive text prompt to generate a corresponding image.
        """
    )

    with gr.Tab("Image to LaTeX"):
        gr.Markdown("### Convert Formula Image to LaTeX Code")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Formula Image",
                )
                submit_btn = gr.Button("Convert to LaTeX")
            with gr.Column():
                latex_output = gr.Textbox(
                    label="LaTeX Code",
                    lines=10,
                )
        submit_btn.click(fn=image_to_latex, inputs=image_input, outputs=latex_output)

    with gr.Tab("Text to Image"):
        gr.Markdown("### Generate Image from Text Prompt")
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    lines=2,
                    placeholder="Enter your image description here...",
                    label="Text Prompt",
                )
                generate_btn = gr.Button("Generate Image")
            with gr.Column():
                image_output = gr.Image(
                    label="Generated Image",
                )
        generate_btn.click(fn=text_to_image, inputs=prompt_input, outputs=image_output)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()