suryabbrj committed
Commit e3d7581
1 Parent(s): 412a1c6

Made the front end accept an image file from the local directory

Files changed (1)
  1. app.py +288 -4
app.py CHANGED
@@ -1,7 +1,291 @@
  import gradio as gr
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
  import gradio as gr
+ import PIL.Image
+ import skimage.io as io
+ from tqdm import tqdm, trange
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+ from typing import Tuple, List, Union, Optional
+ import sys
+ import torch.nn.functional as nnf
+ import torch
+ import numpy as np
+ from torch import nn
+ import clip
+ from huggingface_hub import hf_hub_download
+ import os
 
+ conceptual_weight = hf_hub_download(
+     repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights", filename="conceptual_weights.pt")
+ coco_weight = hf_hub_download(
+     repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights", filename="coco_weights.pt")
 
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ D = torch.device
+ CPU = torch.device('cpu')
+
+
+ def get_device(device_id: int) -> D:
+     if not torch.cuda.is_available():
+         return CPU
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+
+ CUDA = get_device
+
+
+ class MLP(nn.Module):
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+
+ class ClipCaptionModel(nn.Module):
+
+     # @functools.lru_cache #FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(
+             prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         # print(embedding_text.size())  # torch.Size([5, 67, 768])
+         # print(prefix_projections.size())  # torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat,
+                        labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(
+                 prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size *
+                                      prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ # @title Caption prediction
+
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / \
+                 (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(
+                     1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(
+                     -1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(
+                 next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + \
+                 next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)])
+                     for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate2(
+         model,
+         tokenizer,
+         tokens=None,
+         prompt=None,
+         embed=None,
+         entry_count=1,
+         entry_length=67,  # maximum number of words
+         top_p=0.8,
+         temperature=1.,
+         stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / \
+                     (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(
+                     logits, descending=True)
+                 cumulative_probs = torch.cumsum(
+                     nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
+
+
+ is_gpu = False
+ device = CUDA(0) if is_gpu else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+
+ def inference(img, model_name):
+     prefix_length = 10
+
+     model = ClipCaptionModel(prefix_length)
+
+     if model_name == "COCO":
+         model_path = coco_weight
+     else:
+         model_path = conceptual_weight
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     device = CUDA(0) if is_gpu else "cpu"
+     model = model.to(device)
+
+     use_beam_search = False
+     image = io.imread(img.name)
+     pil_image = PIL.Image.fromarray(image)
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(
+             model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix
+
+
+ title = "ProjectX"
280
+ description = "Front-End Application used for ContentModX engine built using Python. To use it, simply upload your image, or click one of the examples to load them."
281
+ article = "<p style='text-align: center'><a href='https://github.com/suryabbrj/python_ml' target='_blank'>Github Repo</a></p>"
282
+
283
+ gr.Interface(
284
+ inference,
285
+ [gr.inputs.Image(type="filepath", label="Input"), gr.inputs.Radio(choices=[
286
+ "Yes", "No"], type="value", default="COCO", label="would you like to constribute this result to the model training dataset (do this only if the image used is not a personal image, of you or anyone else you know.)")],
287
+ gr.outputs.Textbox(label="Output"),
288
+ title=title,
289
+ description=description,
290
+ article=article,
291
+ ).launch(debug=True, share=True)
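A quick way to sanity-check the new image input path outside Gradio is to call inference() directly from a Python shell. The sketch below is hypothetical and not part of this commit; it assumes a local image named example.jpg and that the weight downloads above have completed. Because inference() reads the image via img.name, an ordinary open file handle is enough:

# Hypothetical local smoke test for the captioning pipeline added in this commit.
# Assumes "example.jpg" exists in the working directory.
if __name__ == "__main__":
    with open("example.jpg", "rb") as img_file:  # file object exposes .name for io.imread
        caption = inference(img_file, model_name="COCO")  # "COCO" selects the coco_weight branch
    print(caption)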