jiajunlong committed
Commit 0bff7b9
1 Parent(s): 879d393
Files changed (1)
  1. modeling_tinyllava_elm.py +225 -3
modeling_tinyllava_elm.py CHANGED

@@ -1,7 +1,14 @@
 from dataclasses import dataclass
+import dataclasses
 from typing import List, Optional, Tuple, Union
 import ast
 import re
+from enum import auto, Enum
+import requests
+from PIL import Image
+from io import BytesIO
+import base64
+import time
 
 import torch
 import torch.utils.checkpoint
@@ -12,11 +19,10 @@ from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
 from transformers import CLIPVisionModel, CLIPImageProcessor,SiglipVisionModel, SiglipImageProcessor
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from .configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
 
-from transformers import AutoConfig, AutoModelForCausalLM
-
 # from tinyllava.utils.data_utils import get_value_from_kwargs
 CONTROLLER_HEART_BEAT_EXPIRATION = 30
 WORKER_HEART_BEAT_INTERVAL = 15
@@ -47,6 +53,169 @@ import numpy as np
 from transformers import PretrainedConfig, AutoTokenizer
 
 
+
+logger = logging.get_logger(__name__)
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+LOGDIR = "."
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+    TINY_LLAMA = auto()
+    QWEN_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+            if 'mmtag' in self.version:
+                messages[0] = (init_role, init_msg)
+                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                messages.insert(1, (self.roles[1], "Received."))
+            else:
+                messages[0] = (init_role, "<image>\n" + init_msg)
+
+        if self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+
+
+
+conv_phi_v0 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="phi",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="<|endoftext|>",
+)
+
+
+def load_image_from_base64(image):
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+    def insert_separator(X, sep):
+        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+    input_ids = []
+    offset = 0
+    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+        offset = 1
+        input_ids.append(prompt_chunks[0][0])
+
+    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+        input_ids.extend(x[offset:])
+
+    if return_tensors is not None:
+        if return_tensors == 'pt':
+            return torch.tensor(input_ids, dtype=torch.long)
+        raise ValueError(f'Unsupported tensor type: {return_tensors}')
+    return input_ids
+
+def load_image(image_file):
+    if image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        image = Image.open(image_file).convert("RGB")
+    return image
+
+
 def make_divisible(
     v: Union[float, int],
     divisor: Optional[int] = 8,
@@ -1686,10 +1855,63 @@ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
             position_ids = None
 
         return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+    def chat(
+        self,
+        prompt: str,
+        tokenizer = None,
+        image: str = None,
+        max_new_tokens: int = 512,
+        num_beams = 1,
+        top_p=None,
+        temperature=0
+    ):
+        image_processor = self.vision_tower._image_processor
+
+        if image is not None:
+            prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+        conv = conv_phi_v0.copy()
+        conv.append_message(conv.roles[0], prompt)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+        if image is not None:
+            image = load_image(image)
+            image_tensor = process_images(image, image_processor, self.config).to(self.device)
+
+        input_ids = (
+            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0).to(self.device)
+        )
+        # Generate
+        stime = time.time()
+
+        with torch.inference_mode():
+            output_ids = self.generate(
+                input_ids,
+                images=image_tensor,
+                do_sample=True if temperature > 0 else False,
+                temperature=temperature,
+                top_p=top_p,
+                num_beams=num_beams,
+                pad_token_id=tokenizer.pad_token_id,
+                max_new_tokens=max_new_tokens,
+                use_cache=True,
+                # stopping_criteria=[stopping_criteria],
+            )
+
+        # print('inference over')
+        generation_time = time.time() - stime
+        outputs = tokenizer.batch_decode(
+            output_ids, skip_special_tokens=True
+        )[0]
+
+        outputs = outputs.strip()
+
+        return outputs, generation_time
 
 
 
 
 
 AutoConfig.register("tinyllava", TinyLlavaConfig)
-AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
+AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
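
The helpers this commit adds implement the usual LLaVA-style prompt pipeline: conv_phi_v0 wraps the user turn in the phi chat template, and tokenizer_image_token splices the IMAGE_TOKEN_INDEX sentinel (-200) into the token ids wherever "<image>" appears. A minimal sketch of how the pieces fit together (not part of the commit; it assumes the helpers above are available, e.g. with the file vendored locally, and the tokenizer repo id is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")  # placeholder repo id

conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], "<image>\nWhat is in the picture?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# -> "A chat between a curious user and an artificial intelligence assistant. ...
#     USER: <image>\nWhat is in the picture? ASSISTANT:"

ids = tokenizer_image_token(prompt, tokenizer)  # plain Python list of token ids
# Each "<image>" occurrence becomes the IMAGE_TOKEN_INDEX sentinel (-200);
# prepare_inputs_labels_for_multimodal later replaces the sentinel with the
# projected vision-tower features.
assert IMAGE_TOKEN_INDEX in ids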
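Because the file ends by registering TinyLlavaConfig and TinyLlavaForConditionalGeneration with the Auto classes, a checkpoint that ships this file as remote code can be loaded through AutoModelForCausalLM and driven entirely through the new chat() method. A minimal usage sketch under that assumption (the repo id below is a placeholder, not a real checkpoint):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "path/to/checkpoint"  # placeholder repo id
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,  # executes this modeling file from the repo
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

# chat() builds the phi-style prompt, fetches and preprocesses the image,
# and returns the decoded answer plus the generation time in seconds.
answer, seconds = model.chat(
    prompt="What is in the picture?",
    image="http://images.cocodataset.org/val2017/000000039769.jpg",
    tokenizer=tokenizer,
    max_new_tokens=128,
)
print(answer)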