AlexWortega committed
Commit 28e228f
1 Parent(s): 797767a

Create app.py

Files changed (1)
  1. app.py +199 -0
app.py ADDED
import torch
import torch.nn.functional as nnf
from torch import nn

import clip
import gradio as gr
from PIL import Image
from tqdm import trange
from transformers import GPT2Tokenizer, GPT2LMHeadModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_path = 'coco_prefix_latest.pt'


# Model: ClipCap-style prefix mapping from CLIP image embeddings onto ruGPT-3

class MLP(nn.Module):

    def __init__(self, sizes, bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class ClipCaptionModel(nn.Module):

    def __init__(self, prefix_length, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # plain linear projection for long prefixes (not enough memory for the MLP)
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size,
                                     (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))

    def get_dummy_token(self, batch_size, device):
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens, prefix, mask, labels):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        # prepend the projected CLIP prefix to the token embeddings
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out


class ClipCaptionPrefix(ClipCaptionModel):
    # variant that trains only the prefix projection and keeps GPT frozen

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
prefix_length = 10
model = ClipCaptionModel(prefix_length)
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.to(device)


def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,
        top_p=0.98,
        temperature=1.,
        stop_token='',
):
    model.eval()
    generated_list = []
    # an empty stop_token cannot be encoded, so fall back to the tokenizer's EOS id
    stop_ids = tokenizer.encode(stop_token)
    stop_token_index = stop_ids[0] if stop_ids else tokenizer.eos_token_id
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():
        for entry_idx in trange(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):
                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                # nucleus (top-p) masking; assumes batch size 1
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                # greedy choice; the mask above never removes the top token,
                # so this is effectively argmax decoding
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)

                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)

                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]


def _to_caption(pil_image):
    image = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # CLIP image embedding -> GPT prefix embeddings
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
    return generated_text_prefix


def classify_image(inp):
    # Gradio passes the image as a numpy array
    inp = Image.fromarray(inp)
    return _to_caption(inp)


image = gr.inputs.Image(shape=(128, 128))

iface = gr.Interface(
    fn=classify_image,
    description="https://github.com/AlexWortega/ruImageCaptioning RuImage Captioning trained for an image2text task to predict food calories, by https://t.me/lovedeathtransformers Alex Wortega",
    inputs=image,
    outputs="text",
    examples=[['b9c277a3.jpeg']],
)
iface.launch()
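
For reference, a minimal smoke test of the captioning path outside the Gradio UI. This is a sketch, not part of the commit: it assumes it runs in the same interpreter session as app.py (so classify_image and the loaded models are already defined) and that the example image b9c277a3.jpeg from the examples list sits next to the script.

import numpy as np
from PIL import Image

# Load the example image referenced by the Interface and convert it to the
# numpy array format that Gradio would normally pass to classify_image.
img = np.array(Image.open('b9c277a3.jpeg').convert('RGB'))
print(classify_image(img))  # prints the generated Russian caption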