bipin committed
Commit: cae4936
Parent(s): 52cd6f3

added files

Files changed (4)
  1. app.py +43 -0
  2. gpt2_story_gen.py +11 -0
  3. prefix_clip.py +280 -0
  4. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ import gradio as gr
+
+ from prefix_clip import download_pretrained_model, generate_caption
+ from gpt2_story_gen import generate_story
+
+
+ def main(pil_image, genre, model="Conceptual", use_beam_search=True):
+     model_file = "pretrained_weights.pt"
+
+     download_pretrained_model(model.lower(), file_to_save=model_file)
+
+     image_caption = generate_caption(
+         model_path=model_file,
+         pil_image=pil_image,
+         use_beam_search=use_beam_search,
+     )
+     story = generate_story(image_caption, genre.lower())
+     return story
+
+
+ if __name__ == "__main__":
+     interface = gr.Interface(
+         main,
+         title="image2story",
+         inputs=[
+             gr.inputs.Image(type="pil", source="upload", label="Input"),
+             gr.inputs.Dropdown(
+                 type="value",
+                 label="Story genre",
+                 choices=[
+                     "superhero",
+                     "action",
+                     "drama",
+                     "horror",
+                     "thriller",
+                     "sci_fi",
+                 ],
+             ),
+         ],
+         outputs=gr.outputs.Textbox(label="Generated story"),
+         enable_queue=True,
+     )
+     interface.launch()
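For a quick local smoke test without launching the Gradio UI, main can be called directly. A minimal sketch, assuming a test image saved at the placeholder path example.jpg (not part of the commit):

    from PIL import Image
    from app import main

    # "example.jpg" is a placeholder; any local RGB image works.
    pil_image = Image.open("example.jpg").convert("RGB")

    # Defaults from app.py: Conceptual Captions weights, beam-search captioning.
    print(main(pil_image, genre="sci_fi"))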
gpt2_story_gen.py ADDED
@@ -0,0 +1,11 @@
+ from transformers import pipeline
+
+
+ def generate_story(image_caption, genre):
+     story_gen = pipeline("text-generation", "pranavpsv/genre-story-generator-v2")
+
+     prompt = f"<BOS> <{genre}> {image_caption}"
+     story = story_gen(prompt)[0]["generated_text"]
+     story = story.replace(prompt, "", 1).strip()  # drop the prompt text (str.strip removes characters, not a prefix)
+
+     return story
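The story generator can also be exercised on its own. A minimal sketch, where the caption string is made up to stand in for the output of generate_caption:

    from gpt2_story_gen import generate_story

    # Hypothetical caption; in the app it comes from generate_caption.
    caption = "a dog running through a field of flowers"
    print(generate_story(caption, genre="horror"))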
prefix_clip.py ADDED
@@ -0,0 +1,280 @@
+ import clip
+ import os
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ import sys
+ import gdown
+ from typing import Tuple, List, Union, Optional
+ from transformers import (
+     GPT2Tokenizer,
+     GPT2LMHeadModel,
+     AdamW,
+     get_linear_schedule_with_warmup,
+ )
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+
+
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+ D = torch.device
+ CPU = torch.device("cpu")
+
+
+ def download_pretrained_model(model, file_to_save):
+     conceptual_wt = "14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT"
+     coco_wt = "1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX"
+
+     # download pretrained weights
+     if model == "coco":
+         url = f"https://drive.google.com/uc?id={coco_wt}"
+     elif model == "conceptual":
+         url = f"https://drive.google.com/uc?id={conceptual_wt}"
+     else:
+         raise ValueError(f"unknown model: {model!r}")  # avoid an undefined url below
+     gdown.download(url, file_to_save, quiet=False)
+
+
+ class MLP(nn.Module):
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+
+ class ClipCaptionModel(nn.Module):
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(
+             batch_size, self.prefix_length, dtype=torch.int64, device=device
+         )
+
+     def forward(
+         self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
+     ):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(
+             -1, self.prefix_length, self.gpt_embedding_size
+         )
+         # print(embedding_text.size())  # torch.Size([5, 67, 768])
+         # print(prefix_projections.size())  # torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(
+                 prefix_size, self.gpt_embedding_size * prefix_length
+             )
+         else:
+             self.clip_project = MLP(
+                 (
+                     prefix_size,
+                     (self.gpt_embedding_size * prefix_length) // 2,
+                     self.gpt_embedding_size * prefix_length,
+                 )
+             )
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ def generate_beam(
+     model,
+     tokenizer,
+     beam_size: int = 5,
+     prompt=None,
+     embed=None,
+     entry_length=67,
+     temperature=1.0,
+     stop_token: str = ".",
+ ):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
+                     beam_size, -1
+                 )
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
+                 generated.shape[0], 1, -1
+             )
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [
+         tokenizer.decode(output[: int(length)])
+         for output, length in zip(output_list, seq_lengths)
+     ]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     entry_length=67,  # maximum number of words
+     top_p=0.8,
+     temperature=1.0,
+     stop_token: str = ".",
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(
+                     nnf.softmax(sorted_logits, dim=-1), dim=-1
+                 )
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
+
+ def generate_caption(model_path, pil_image, use_beam_search):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+     prefix_length = 10
+
+     model = ClipCaptionModel(prefix_length)
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     model = model.to(device)
+
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         image_caption = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         image_caption = generate2(model, tokenizer, embed=prefix_embed)
+
+     return image_caption
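Taken together, the captioning utilities compose the same way app.py uses them. A minimal sketch, assuming a local test image at the placeholder path photo.jpg:

    import PIL.Image
    from prefix_clip import download_pretrained_model, generate_caption

    weights = "pretrained_weights.pt"
    download_pretrained_model("coco", file_to_save=weights)  # or "conceptual"

    # "photo.jpg" is a placeholder path for any local image.
    image = PIL.Image.open("photo.jpg").convert("RGB")
    print(generate_caption(model_path=weights, pil_image=image, use_beam_search=True))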
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ numpy
+ gdown
+ transformers
+ tqdm
+ Pillow
+ scikit-image
+ git+https://github.com/openai/CLIP.git