gabrielchua committed
Commit 574151f
1 Parent(s): a18d449

Update app.py

Files changed (1):
  app.py +52 -38
app.py CHANGED
@@ -12,26 +12,32 @@ from janus.utils.io import load_pil_images
 model_path = "deepseek-ai/Janus-1.3B"
 
 # Load the VLChatProcessor and tokenizer
+print("Loading VLChatProcessor and tokenizer...")
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 # Load the MultiModalityCausalLM model
+print("Loading MultiModalityCausalLM model...")
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
     model_path, trust_remote_code=True
 )
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+# Move the model to GPU with bfloat16 precision for efficiency
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+vl_gpt = vl_gpt.to(torch.bfloat16 if device.type == "cuda" else torch.float32).to(device).eval()
 
 @spaces.GPU(duration=120)
-def image_to_latex(image: Image.Image) -> str:
+def text_image_to_text(user_text: str, user_image: Image.Image) -> str:
     """
-    Convert an uploaded image of a formula into LaTeX code.
+    Generate a textual response based on user-provided text and image.
+    This can be used for tasks like converting an image of a formula to LaTeX code
+    or generating descriptive captions.
     """
-    # Define the conversation with the uploaded image
+    # Define the conversation with user-provided text and image
    conversation = [
        {
            "role": "User",
-            "content": "<image_placeholder>\nConvert the formula into latex code.",
-            "images": [image],
+            "content": user_text,
+            "images": [user_image],
        },
        {"role": "Assistant", "content": ""},
    ]
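Review note: this hunk generalizes the single-purpose `image_to_latex` into `text_image_to_text` and adds a CPU fallback so the model also loads without CUDA. A minimal usage sketch follows; the file path is hypothetical, and whether `<image_placeholder>` must still appear in `user_text` for VLChatProcessor to splice in the image embeddings is an assumption worth verifying against the Janus reference code:

```python
# Minimal usage sketch for the generalized entry point.
from PIL import Image

formula_img = Image.open("formula.png")  # hypothetical input file
# The old hard-coded prompt included "<image_placeholder>"; keeping it here
# is an assumption about what VLChatProcessor expects.
latex = text_image_to_text(
    "<image_placeholder>\nConvert the formula into latex code.", formula_img
)
print(latex)
```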
@@ -42,22 +48,23 @@ def image_to_latex(image: Image.Image) -> str:
     # Prepare the inputs for the model
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(vl_gpt.device)
+    ).to(device)
 
     # Prepare input embeddings
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     # Generate the response from the model
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=512,
-        do_sample=False,
-        use_cache=True,
-    )
+    with torch.no_grad():
+        outputs = vl_gpt.language_model.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=prepare_inputs.attention_mask,
+            pad_token_id=tokenizer.eos_token_id,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            max_new_tokens=512,
+            do_sample=False,
+            use_cache=True,
+        )
 
     # Decode the generated tokens to get the answer
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
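Review note: the new `torch.no_grad()` wrapper is safe but likely redundant, since recent transformers releases decorate `GenerationMixin.generate()` with `@torch.no_grad()` already. A sketch of the same pattern using the slightly stricter `torch.inference_mode()`:

```python
# Sketch: explicit inference context around generation. inference_mode()
# disables autograd tracking like no_grad() and additionally skips tensor
# version counting, which can shave a little overhead.
import torch

def generate_without_grad(model, **gen_kwargs):
    with torch.inference_mode():
        return model.generate(**gen_kwargs)
```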
@@ -88,12 +95,15 @@ def text_to_image(prompt: str) -> Image.Image:
 
     # Encode the prompt
     input_ids = vl_chat_processor.tokenizer.encode(prompt_text)
-    input_ids = torch.LongTensor(input_ids)
+    input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
 
     # Prepare tokens for generation
-    tokens = torch.zeros((2, len(input_ids)), dtype=torch.int).cuda()
-    tokens[0, :] = input_ids
-    tokens[1, :] = vl_chat_processor.pad_id
+    parallel_size = 1  # Adjust based on GPU memory
+    tokens = torch.zeros((parallel_size*2, len(input_ids[0])), dtype=torch.int).to(device)
+    for i in range(parallel_size*2):
+        tokens[i, :] = input_ids
+        if i % 2 != 0:
+            tokens[i, 1:-1] = vl_chat_processor.pad_id
 
     # Get input embeddings
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
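Review note: besides parameterizing the batch, this hunk changes the unconditional row construction. The old code blanked the entire second row with `pad_id`; the new loop keeps the first and last tokens and pads only the interior, which matches the classifier-free-guidance layout in the Janus reference generation code (conditional prompts at even row indices, unconditional copies at odd). Inside the unchanged sampling loop, the reference implementation combines the paired logits roughly as `uncond + cfg_weight * (cond - uncond)`. A self-contained sketch of the layout:

```python
# Sketch of the CFG token layout built above: row 2k holds the full
# conditional prompt, row 2k+1 an unconditional copy whose interior
# tokens are replaced by pad_id (the first and last tokens are kept).
import torch

def build_cfg_tokens(input_ids: torch.Tensor, parallel_size: int, pad_id: int) -> torch.Tensor:
    tokens = torch.zeros((parallel_size * 2, input_ids.shape[-1]), dtype=torch.int)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:                # unconditional branch
            tokens[i, 1:-1] = pad_id  # pad everything between BOS and the final token
    return tokens
```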
@@ -106,7 +116,7 @@ def text_to_image(prompt: str) -> Image.Image:
     temperature = 1
 
     # Initialize tensor to store generated tokens
-    generated_tokens = torch.zeros((1, image_token_num_per_image), dtype=torch.int).cuda()
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
     for i in range(image_token_num_per_image):
         if i == 0:
@@ -128,14 +138,14 @@ def text_to_image(prompt: str) -> Image.Image:
         generated_tokens[:, i] = next_token.squeeze(dim=-1)
 
         # Prepare for the next step
-        next_token_combined = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+        next_token_combined = torch.cat([next_token, next_token], dim=0).view(-1)
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token_combined)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
     # Decode the generated tokens to get the image
     dec = vl_gpt.gen_vision_model.decode_code(
         generated_tokens.to(dtype=torch.int),
-        shape=[1, 8, img_size//patch_size, img_size//patch_size]
+        shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]
     )
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
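Review note: with `parallel_size = 1` the simplified `torch.cat([next_token, next_token], dim=0)` is equivalent to the old interleaving, but for `parallel_size > 1` it would pair rows in the wrong order (all conditional tokens first, then all unconditional), assuming `next_token` has shape `(parallel_size, 1)`; the original `dim=1` interleave is the safer general form. Downstream, the decoded array can be turned into images, e.g.:

```python
# Sketch: converting the decoded array (shape [parallel_size, H, W, 3],
# uint8 after the transpose/clip above) into PIL images.
import numpy as np
from PIL import Image

def to_pil_images(dec: np.ndarray) -> list:
    return [Image.fromarray(dec[i]) for i in range(dec.shape[0])]
```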
@@ -152,30 +162,36 @@ with gr.Blocks() as demo:
         """
         # Janus-1.3B Gradio Demo
         This demo showcases two functionalities using the Janus-1.3B model:
-        1. **Image to LaTeX**: Upload an image of a mathematical formula to convert it into LaTeX code.
+        1. **Text + Image to Text**: Input both text and an image to generate a textual response.
         2. **Text to Image**: Enter a descriptive text prompt to generate a corresponding image.
         """
     )
 
-    with gr.Tab("Image to LaTeX"):
-        gr.Markdown("### Convert Formula Image to LaTeX Code")
+    with gr.Tab("Text + Image to Text"):
+        gr.Markdown("### Generate Text Based on Input Text and Image")
         with gr.Row():
             with gr.Column():
-                image_input = gr.Image(
+                user_text_input = gr.Textbox(
+                    lines=2,
+                    placeholder="Enter your instructions or description here...",
+                    label="Input Text",
+                )
+                user_image_input = gr.Image(
                     type="pil",
-                    label="Upload Formula Image",
+                    label="Upload Image",
                     tool="editor",
                 )
-                submit_btn = gr.Button("Convert to LaTeX")
+                submit_btn = gr.Button("Generate Text")
             with gr.Column():
-                latex_output = gr.Textbox(
-                    label="LaTeX Code",
-                    lines=10,
+                text_output = gr.Textbox(
+                    label="Generated Text",
+                    lines=15,
+                    interactive=False,
                 )
-            submit_btn.click(fn=image_to_latex, inputs=image_input, outputs=latex_output)
+            submit_btn.click(fn=text_image_to_text, inputs=[user_text_input, user_image_input], outputs=text_output)
 
     with gr.Tab("Text to Image"):
-        gr.Markdown("### Generate Image from Text Prompt")
+        gr.Markdown("### Generate Image Based on Text Prompt")
         with gr.Row():
             with gr.Column():
                 prompt_input = gr.Textbox(
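Review note: the `tool="editor"` argument kept in the context lines above is Gradio 3.x API; Gradio 4.x removed `tool` from `gr.Image` (editing moved to the separate `gr.ImageEditor` component), so on a current Space this constructor would likely raise a TypeError. A sketch of the plain-upload equivalent, assuming Gradio 4.x:

```python
# Assumption: Gradio 4.x, where gr.Image no longer accepts `tool`.
import gradio as gr

user_image_input = gr.Image(type="pil", label="Upload Image")
```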
@@ -189,9 +205,7 @@ with gr.Blocks() as demo:
                     label="Generated Image",
                 )
             generate_btn.click(fn=text_to_image, inputs=prompt_input, outputs=image_output)
-            )
-
+
 # Launch the Gradio app
 if __name__ == "__main__":
     demo.launch()
-
 