import copy
import os
import sys
from functools import partial

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)

import contextlib

import torch.utils.checkpoint
import torch.nn as nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from .resampler import create_resampler

from transformers.utils import logging

logger = logging.get_logger(__name__)


class InternLMXComposerForCausalLM(PreTrainedModel):
    """Vision-language model: EVA ViT encoder + resampler + InternLM decoder.

    Images are encoded by ``visual_encoder``, compressed by the ``Qformer``
    resampler, projected into the language model's embedding space by
    ``internlm_proj`` and wrapped in learned start/end image flags before being
    spliced into the text prompt embeddings.
    """

    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

    # Default generation arguments; ``get_gen_args`` deep-copies this dict
    # before applying per-call overrides, so it is never mutated.
    gen_config = dict(
        num_beams=5,
        do_sample=True,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )

    def __init__(self, config):
        super().__init__(config)
        self.max_length = config.max_length
        print(f'Set max length to {self.max_length}')

        print('Init VIT ... ', end='')
        self.visual_encoder = create_eva_vit_g(img_size=448)
        self.ln_vision = nn.Identity()
        self.supports_gradient_checkpointing = True
        print('Done')

        print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer = create_resampler(num_query_token=256)
        print('Done')

        print('Init InternLM ... ', end='')
        # Learned (but frozen) markers placed before/after the image tokens.
        # NOTE(review): width 4096 must equal the LM hidden size for the
        # torch.cat in encode_img to succeed — confirm against config.
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

        # Parse the major version robustly ("10.1" must not read as 1).
        torch_major = int(torch.__version__.split('.')[0])
        if torch_major == 1:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            assert torch_major >= 2, f'Unsupported torch version {torch.__version__}'
            # Speed up LLM init by building it on the meta device (no real
            # allocation). NOTE(review): weights are presumably materialized
            # later (e.g. by from_pretrained); the explicit steps below were
            # left disabled by the original author:
            #   self.internlm_model.to_empty(device=config.device).to(torch.float16)
            #   self.internlm_model.tie_weights()
            #   self.internlm_model.to(config.device)
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)
        self.internlm_proj = nn.Linear(4096,
                                       self.internlm_model.config.hidden_size)
        print('Done')

        self.vis_processor = transforms.Compose([
            transforms.Resize((448, 448),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        # Must be assigned by the caller before encode_text/decode_text are used.
        self.tokenizer = None

    @property
    def eoh(self):
        """End-of-human marker appended after the user turn."""
        # NOTE(review): empty here; if the tokenizer defines a dedicated
        # end-of-human token it should be returned instead — confirm.
        return ''

    @property
    def eoa(self):
        """End-of-answer marker at which generated text is truncated."""
        # NOTE(review): empty here; decode_text therefore skips truncation.
        return ''

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.internlm_model.get_input_embeddings()

    def _set_gradient_checkpointing(self, module, value=False):
        """Enable gradient checkpointing on every InternLM submodule.

        Only the enabling direction is handled; ``value=False`` is a no-op.
        """
        if value:
            self.internlm_model.apply(
                partial(self.internlm_model._set_gradient_checkpointing,
                        value=True))

    def encode_img(self, image):
        """Encode an image into LM-space embeddings framed by image flags.

        Args:
            image: ``None``; a path to an image file; a ``PIL.Image.Image``; or
                an already preprocessed ``torch.Tensor`` batch (presumably
                (B, 3, 448, 448) to match ``vis_processor`` — confirm).

        Returns:
            ``None`` when ``image`` is ``None``. Otherwise the projected
            resampler output with ``flag_image_start`` prepended and
            ``flag_image_end`` appended along dim 1.
        """
        if image is None:
            return None
        if isinstance(image, str):
            # Context manager releases the file handle once the image is read.
            with Image.open(image) as pil_image:
                image = pil_image.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")

        if isinstance(image, Image.Image):
            # Resize/normalize to a single-image batch of shape (1, 3, 448, 448).
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)

        # Defined for every input kind, after preprocessing.
        device = image.device
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
        query_output = self.Qformer(image_embeds)
        inputs_internlm = self.internlm_proj(query_output)

        batch_size = inputs_internlm.shape[0]
        inputs_internlm = torch.cat([
            self.flag_image_start.expand(batch_size, -1, -1),
            inputs_internlm,
            self.flag_image_end.expand(batch_size, -1, -1),
        ], dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
        """Tokenize ``text`` and return its token embeddings on ``self.device``.

        Args:
            text: input string (or whatever ``self.tokenizer`` accepts).
            add_special_tokens: forwarded to the tokenizer (e.g. to add BOS).

        Returns:
            Embeddings from ``internlm_model.model.embed_tokens``; presumably
            (1, seq_len, hidden) for a single string — confirm.
        """
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
        return text_embeds

    def decode_text(self, out_embeds):
        """Decode generated token ids into text, cut at the end-of-answer marker.

        Args:
            out_embeds: generated token ids (despite the name, these are the ids
                returned by ``generate`` and fed to ``tokenizer.batch_decode``).

        Returns:
            Decoded text of the first sequence in the batch.
        """
        out_text = self.tokenizer.batch_decode(out_embeds,
                                               skip_special_tokens=True)[0]
        # str.split('') raises ValueError, so only truncate when eoa is set.
        if self.eoa:
            out_text = out_text.split(self.eoa)[0]
        return out_text

    def wrap_text(self, user_text, bot_text='', add_special=True):
        """Format a single user/bot exchange as a plain-text prompt."""
        eoh = self.eoh if add_special else ''
        text = f'<|User|>:{user_text}{eoh}\n<|Bot|>:{bot_text}'
        return text

    def get_gen_args(self, **kwargs):
        """Return a fresh copy of ``gen_config`` updated with ``kwargs``."""
        new_kargs = copy.deepcopy(self.gen_config)
        new_kargs.update(kwargs)
        return new_kargs

    def _generate_from_embeds(self, prompt_embeds, **kwargs):
        """Run the LM on prompt embeddings and decode the first output sequence."""
        out_ids = self.internlm_model.generate(inputs_embeds=prompt_embeds,
                                               **self.get_gen_args(**kwargs))
        return self.decode_text(out_ids)

    def generate(self, text, image=None, **kwargs):
        """Single-turn generation from ``text`` and an optional ``image``.

        Extra ``kwargs`` override the defaults in ``gen_config``.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        return self._generate_from_embeds(prompt_embeds, **kwargs)

    def chat(self, text, image=None, history=None, **kwargs):
        """Multi-turn chat step.

        Args:
            text: the user message.
            image: optional image (see ``encode_img``).
            history: list of per-turn embedding tensors from earlier calls,
                or ``None`` to start a new conversation.
            **kwargs: overrides for ``gen_config``.

        Returns:
            ``(out_text, history)``. ``history`` is the same list object that
            was passed in (or a new one), with this turn's embeddings appended.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds,
                                         history=history)
        out_text = self._generate_from_embeds(prompt_embeds, **kwargs)

        # Store a cleaned version of this turn: the decoded answer (already cut
        # at eoa by decode_text) and the prompt wrapped without the eoh marker.
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds, img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self, text_embeds, img_embeds=None, history=None,
                    add_special=True):
        """Assemble ``<|User|>: [image] text [eoh]\\n<|Bot|>:`` as embeddings.

        Args:
            text_embeds: embeddings of the user text.
            img_embeds: optional image embeddings from ``encode_img``.
            history: optional list of earlier turn embeddings, prepended
                along dim 1.
            add_special: include the ``eoh`` marker before the bot tag.

        Returns:
            The concatenated prompt embeddings (dim 1 is the sequence axis).
        """
        if add_special:
            prompt_segs = ['<|User|>:', f'{self.eoh}\n<|Bot|>:']
        else:
            prompt_segs = ['<|User|>:', '<|Bot|>:']  # used in wrap history

        # Special tokens (e.g. BOS) go only on the first segment of a fresh
        # conversation; with history present the sequence already started.
        prompt_seg_embeds = [
            self.encode_text(seg,
                             add_special_tokens=(history is None and i == 0))
            for i, seg in enumerate(prompt_segs)
        ]

        if img_embeds is None:
            # Zero-length placeholder keeps the concatenation uniform.
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_embeds = torch.cat([
            prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
        ], dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds