gokaygokay committed
Commit: e80bc79
Parent(s): 7d5cd7f

Upload moondream.py with huggingface_hub

Files changed (1):
  1. moondream.py +179 -0
moondream.py ADDED
@@ -0,0 +1,179 @@
+ import torch
+ from .vision_encoder import VisionEncoder
+ from .configuration_moondream import MoondreamConfig
+ from transformers import PreTrainedModel
+ import re
+
+ from .modeling_phi import PhiForCausalLM
+ from .configuration_moondream import PhiConfig
+
+ class Moondream(PreTrainedModel):
+     config_class = MoondreamConfig
+     _supports_flash_attn_2 = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.vision_encoder = VisionEncoder()
+
+         if type(config.phi_config) == dict:
+             phi_config = PhiConfig(
+                 **config.phi_config, attn_implementation=config._attn_implementation
+             )
+         else:
+             phi_config = config.phi_config
+         self.text_model = PhiForCausalLM(phi_config)
+
+     @property
+     def device(self):
+         return self.text_model.device
+
+     def encode_image(self, image):
+         return self.vision_encoder(image)
+
+     def input_embeds(self, prompt, image_embeds, tokenizer):
+         def _tokenize(txt):
+             return tokenizer(
+                 txt, return_tensors="pt", add_special_tokens=False
+             ).input_ids.to(self.device)
+
+         text_emb = self.text_model.get_input_embeddings()
+
+         # Add BOS token
+         embeds = []
+         embeds.append(
+             text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
+         )
+
+         if "<image>" not in prompt:
+             embeds.append(text_emb(_tokenize(prompt)))
+         else:
+             assert prompt.count("<image>") == 1
+             before, after = prompt.split("<image>")
+             if len(before) > 0:
+                 embeds.append(text_emb(_tokenize(before)))
+             embeds.append(image_embeds.to(self.device))
+             if len(after) > 0:
+                 embeds.append(text_emb(_tokenize(after)))
+
+         return torch.cat(embeds, dim=1)
+
+     def generate(
+         self,
+         image_embeds,
+         prompt,
+         tokenizer,
+         eos_text="<END>",
+         max_new_tokens=128,
+         **kwargs,
+     ):
+         eos_tokens = tokenizer(eos_text, add_special_tokens=False)[0].ids
+
+         generate_config = {
+             "eos_token_id": eos_tokens,
+             "bos_token_id": tokenizer.bos_token_id,
+             "pad_token_id": tokenizer.eos_token_id,
+             "max_new_tokens": max_new_tokens,
+             **kwargs,
+         }
+
+         with torch.no_grad():
+             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+             output_ids = self.text_model.generate(
+                 inputs_embeds=inputs_embeds, **generate_config
+             )
+
+         return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+     def answer_question(
+         self,
+         image_embeds,
+         question,
+         tokenizer,
+         chat_history="",
+         result_queue=None,
+         **kwargs,
+     ):
+         prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
+         answer = self.generate(
+             image_embeds,
+             prompt,
+             eos_text="<END>",
+             tokenizer=tokenizer,
+             max_new_tokens=512,
+             **kwargs,
+         )[0]
+         cleaned_answer = re.sub("<$|<END$", "", answer).strip()
+
+         # Use the result_queue to pass the result if it is provided
+         if result_queue:
+             result_queue.put(cleaned_answer)
+         else:
+             return cleaned_answer
+
+     def batch_answer(
+         self,
+         images,
+         prompts,
+         tokenizer,
+         **kwargs,
+     ):
+         eos_tokens = tokenizer("<END>", add_special_tokens=False)[0].ids
+
+         image_embeds = self.encode_image(images)
+
+         templated_prompts = [
+             f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
+         ]
+         prompt_embs = [
+             self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
+             for prompt, image_embed in zip(templated_prompts, image_embeds)
+         ]
+
+         bos_emb = prompt_embs[0][0]
+         max_len = max([p.shape[0] for p in prompt_embs])
+
+         inputs_embeds = torch.cat(
+             [
+                 torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
+                 for p in prompt_embs
+             ],
+             dim=0,
+         )
+         attention_mask = torch.cat(
+             [
+                 torch.cat(
+                     [
+                         torch.zeros(
+                             1,
+                             max_len - p.shape[0],
+                             device=self.device,
+                             dtype=torch.long,
+                         ),
+                         torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
+                     ],
+                     dim=1,
+                 )
+                 for p in prompt_embs
+             ],
+             dim=0,
+         )
+
+         generate_config = {
+             "eos_token_id": eos_tokens,
+             "bos_token_id": tokenizer.bos_token_id,
+             "pad_token_id": tokenizer.eos_token_id,
+             "max_new_tokens": 512,
+             **kwargs,
+         }
+
+         with torch.no_grad():
+             output_ids = self.text_model.generate(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 **generate_config,
+             )
+
+         return [
+             re.sub("<$|<END$", "", x).strip()
+             for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+         ]
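
Usage sketch (not part of the commit): a minimal example of how the Moondream class above could be driven, assuming the file is published in a Hugging Face repo alongside its vision_encoder, configuration_moondream, and modeling_phi modules so that trust_remote_code=True resolves to it. The repo id and image path below are placeholders, not values taken from this commit.

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id; point this at the checkpoint that ships this moondream.py.
model_id = "your-namespace/moondream"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("example.jpg")          # placeholder image path
image_embeds = model.encode_image(image)   # run the vision encoder once per image

# Single-image question answering via answer_question()
print(model.answer_question(image_embeds, "What is in this picture?", tokenizer))

# Batched question answering over several images and prompts via batch_answer()
print(
    model.batch_answer(
        images=[image, image],
        prompts=["Describe the scene.", "What colors are present?"],
        tokenizer=tokenizer,
    )
)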