# Source: 360VL_PHI / modeling_360vl.py
# Hugging Face file-viewer header (kept for provenance):
#   uploader: Secur3
#   commit:   6943b72 - "trimmed down sys prompt"
#   size:     35.7 kB
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from PIL import Image
from abc import ABC, abstractmethod
import os
import math
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from functools import partial
from transformers.configuration_utils import PretrainedConfig
from timm.models.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from torch.nn import functional as F
import math
from einops import rearrange
# Serving-related settings; nothing in the visible part of this file reads them.
# NOTE(review): units look like seconds — confirm against the serving code.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
# Label value written over image positions and padding so they are excluded
# from the loss (it equals CrossEntropyLoss's default ignore_index).
IGNORE_INDEX = -100
# Sentinel id in input_ids marking where image features are spliced in.
IMAGE_TOKEN_INDEX = -200
# Textual image markers; not referenced in the visible part of this file.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder exposing one hidden layer as image features.

    Args:
        vision_tower: Name or path handed to the CLIP ``from_pretrained`` calls.
        args: Object providing ``mm_vision_select_layer`` and, optionally,
            ``mm_vision_select_feature`` (defaults to ``'patch'``).
        delay_load: When true, only the CLIP config is fetched and
            ``load_model`` must be called later before ``forward``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if delay_load:
            # Keep just the config so hidden_size / num_patches stay available.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        else:
            self.load_model()

    def load_model(self):
        """Load the image processor and the CLIP weights, then freeze the weights."""
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and optionally drop its first token.

        Raises:
            ValueError: If ``select_feature`` is neither ``'patch'`` nor ``'cls_patch'``.
        """
        selected = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            # Skip position 0 along the token axis.
            return selected[:, 1:]
        if self.select_feature == 'cls_patch':
            return selected
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    def _encode(self, pixel_values, out_dtype):
        """Run the tower on a batched tensor and cast the selected features to ``out_dtype``."""
        outputs = self.vision_tower(
            pixel_values.to(device=self.device, dtype=self.dtype),
            output_hidden_states=True,
        )
        return self.feature_select(outputs).to(out_dtype)

    @torch.no_grad()
    def forward(self, images):
        """Encode a batched tensor, or a list of single images (returns a list)."""
        if type(images) is list:
            # Each list entry is treated as one unbatched image.
            return [self._encode(single.unsqueeze(0), single.dtype) for single in images]
        return self._encode(images, images.dtype)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # Before load_model() only the config fetched in __init__ exists.
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        patches_per_side = self.config.image_size // self.config.patch_size
        return patches_per_side ** 2
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the vision tower named by ``vision_tower_cfg``.

    The tower name is read from ``mm_vision_tower`` and, if that attribute is
    absent, from ``vision_tower``.

    Args:
        vision_tower_cfg: Config object carrying the tower name and the
            attributes ``CLIPVisionTower`` reads from it.
        **kwargs: Forwarded to ``CLIPVisionTower`` (e.g. ``delay_load``).

    Returns:
        A ``CLIPVisionTower`` instance.

    Raises:
        ValueError: If no tower name is configured, or the name is neither an
            existing path nor prefixed with ``openai`` / ``laion``.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # os.path.exists(None) would raise a TypeError; fail with a clear
        # message of the same kind as the unknown-tower case instead.
        raise ValueError('No vision tower configured: set `mm_vision_tower` or `vision_tower`.')
    if os.path.exists(vision_tower) or vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    raise ValueError(f'Unknown vision tower: {vision_tower}')
class HoneybeeVisualProjectorConfig(PretrainedConfig):
    """Configuration for the visual projector (abstractor).

    Values that are not named parameters here are handed to
    ``PretrainedConfig`` through ``**kwargs``.
    NOTE(review): ``CAbstractor.build_net`` reads ``depth``, ``mlp_depth`` and
    ``num_queries`` from this config; they presumably arrive via ``**kwargs``
    — confirm against the shipped projector config file.
    """

    model_type = "mllm_visual_projector"

    def __init__(
        self,
        projector_type: str = "resampler",
        hidden_size: int = 1024,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 16,
        intermediate_size: int = 4096,
        attention_probs_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-6,
        encoder_hidden_size: int = 1024,  # This will be overwritten by vision_model's hidden_size
        pos_emb=False,
        feature_layer_index=-1,  # vision feature layer index; -1: last layer
        num_eos_tokens=1,
        use_cls=True,
        prenorm=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Projector architecture.
        self.projector_type = projector_type
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Regularisation / initialisation.
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Interface to the vision encoder.
        self.encoder_hidden_size = encoder_hidden_size
        self.pos_emb = pos_emb
        self.feature_layer_index = feature_layer_index
        self.num_eos_tokens = num_eos_tokens
        self.use_cls = use_cls
        self.prenorm = prenorm

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the projector config, unwrapping it from a full model config if needed."""
        loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A full QH_360VL config nests the projector settings under one key.
        if loaded.get("model_type") == "QH_360VL":
            loaded = loaded["visual_projector_config"]

        return cls.from_dict(loaded, **kwargs)
def build_pos_embeds(
    config: HoneybeeVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
):
    """Create the learnable positional embedding, or ``None`` when disabled.

    Returns:
        A ``(1, num_input_tokens, vision_hidden_size)`` parameter initialised
        from a truncated normal (std 0.02) if ``config.pos_emb`` is truthy,
        otherwise ``None``.
    """
    if not config.pos_emb:
        return None

    table = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
    nn.init.trunc_normal_(table, mean=0.0, std=0.02)
    return table
def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: int):
    """Create the learnable tokens appended after the projected features.

    Returns:
        A ``(1, config.num_eos_tokens, output_hidden_size)`` parameter
        re-initialised from a truncated normal with
        ``std=config.initializer_range``, or ``None`` when
        ``config.num_eos_tokens`` is falsy (e.g. 0).
    """
    count = config.num_eos_tokens
    if not count:
        return None

    # think tokens
    tokens = torch.nn.Parameter(torch.randn(1, count, output_hidden_size))
    nn.init.trunc_normal_(tokens, mean=0.0, std=config.initializer_range)
    return tokens
def build_prenorm(config: HoneybeeVisualProjectorConfig):
    """Return a ``LayerNorm`` over ``config.encoder_hidden_size`` if ``config.prenorm`` is truthy, else ``None``."""
    return LayerNorm(config.encoder_hidden_size) if config.prenorm else None
def build_mlp(depth, hidden_size, output_hidden_size):
    """Build an MLP of ``depth`` linear layers separated by SiLU activations.

    The first layer maps ``hidden_size -> output_hidden_size``; every further
    layer keeps ``output_hidden_size``. A ``depth`` of 1 or less yields the
    single input layer.
    """
    stack = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(depth - 1):
        stack += [nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size)]
    return nn.Sequential(*stack)
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of positional embeddings to ``tgt_size`` tokens.

    Args:
        abs_pos: Tensor whose dimension 1 holds the source token count; it is
            reshaped to ``(1, side, side, C)``, so it is expected to look like
            ``(1, L, C)`` with ``L`` a perfect square.
        tgt_size: Target token count (its integer square root is the target side).

    Returns:
        ``abs_pos`` itself when source and target sides match; otherwise a
        bicubically interpolated ``(tgt_side * tgt_side, C)`` tensor in the
        original dtype. Note the two paths differ in rank.
    """
    src_side = int(math.sqrt(abs_pos.size(1)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # Interpolate in float32 on a channels-first (1, C, side, side) layout.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # Back to channels-last and flatten batch + both spatial axes into one.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
class Projector(nn.Module):
    """Base projector class.

    Holds the optional pre-norm, positional embedding and trailing learnable
    tokens shared by all projectors. Subclasses implement ``build_net`` (create
    the layers) and ``_forward`` (map visual features to the LM hidden size).

    Args:
        config: Projector configuration.
        num_input_tokens: Length of the positional-embedding table.
        output_hidden_size: Feature size expected by the language model.
    """

    def __init__(
        self,
        config: HoneybeeVisualProjectorConfig,
        num_input_tokens: int,
        output_hidden_size: int,
    ):
        super().__init__()
        self.config = config
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # think tokens
        self.eos_tokens = build_eos_tokens(config, output_hidden_size)

        # pos emb
        self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)

        self.prenorm = build_prenorm(config)

        self.build_net()

    def build_net(self):
        """Create the projector layers; must be overridden."""
        raise NotImplementedError()

    def _forward(self, x):
        """Project visual features; must be overridden."""
        raise NotImplementedError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply pre-norm, positional embedding, the projector and the extra tokens.

        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone
                (CLIP visual encoder). The input tensor is not modified.

        Returns:
            (B, L', output_hidden_size) tensor, where the last
            ``num_eos_tokens`` positions are the learnable tokens when enabled.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            # The first table entry is skipped before resizing to x's length.
            # NOTE(review): presumably that entry is the CLS position — confirm.
            pos_emb = get_abs_pos(self.pos_emb[:, 1:], x.size(1))
            # Out-of-place add: the previous `x += pos_emb` wrote into the
            # caller's tensor whenever prenorm was disabled. Casting to x's
            # dtype keeps the result dtype the in-place add produced.
            x = x + pos_emb.to(device=x.device, dtype=x.dtype)

        x = self._forward(x)  # (B, L, output_hidden_size)

        if self.eos_tokens is not None:
            batch_size = x.size(0)
            x = torch.cat([x, self.eos_tokens.expand(batch_size, -1, -1)], dim=1)
        return x
class ConvProjector(Projector):
    """Projector that treats the token sequence as a square 2-D feature map."""

    def _forward(self, x):
        """Run ``self.net`` on the 2-D layout, then ``self.readout`` on the tokens.

        ``x`` is ``[B, L, dim]``; ``L`` must be a perfect square because the
        sequence is folded into an ``h x w`` grid with ``h == w``.
        """
        side = int(x.size(1) ** 0.5)
        grid = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)
        convolved = self.net(grid)
        tokens = rearrange(convolved, "b d h w -> b (h w) d")
        return self.readout(tokens)
class CAbstractor(ConvProjector):
    """C-Abstractor"""

    def build_net(self):
        """Create ``self.net`` (stage -> adaptive average pool -> stage) and ``self.readout``.

        Reads ``encoder_hidden_size``, ``hidden_size``, ``depth``,
        ``mlp_depth`` and ``num_queries`` from the config; ``num_queries``
        must be a perfect square because it defines the pooled grid size.
        """
        cfg = self.config
        in_channels = cfg.encoder_hidden_size
        mid_channels = cfg.hidden_size
        out_channels = self.output_hidden_size
        stage_depth = cfg.depth
        readout_depth = cfg.mlp_depth
        n_queries = cfg.num_queries

        assert (n_queries ** 0.5).is_integer(), "n_queries must be square number"
        grid_side = int(n_queries ** 0.5)

        def make_stage(channels_in, channels_out):
            # RegBlock = ResBlock + SE
            return RegStage(
                stage_depth,
                channels_in,
                channels_out,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )

        pre_pool = make_stage(in_channels, mid_channels)
        pool = nn.AdaptiveAvgPool2d((grid_side, grid_side))
        post_pool = make_stage(mid_channels, mid_channels)

        self.net = nn.Sequential(pre_pool, pool, post_pool)
        self.readout = build_mlp(readout_depth, mid_channels, out_channels)
class IdentityMap(nn.Module):
    """Projector that returns its input unchanged."""

    def forward(self, x, *args, **kwargs):
        # Extra positional / keyword arguments are accepted and ignored.
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a two-layer GELU MLP with a residual connection.

    The residual is taken from the normalised input, not the raw input.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        layers = (
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
def build_honeybee_projector(config, projector_type, num_tokens,lm_hidden_size):
    """Build projector (abstractor) and query_tokens (optionally for resampler)

    Looks ``projector_type`` up in the table of known abstractors and builds it
    with ``(config, num_tokens, lm_hidden_size)``.

    Raises:
        KeyError: If ``projector_type`` is not a known abstractor.
    """
    registry = {
        "c-abs": CAbstractor,
    }
    abstractor_cls = registry[projector_type]
    # lm_hidden_size is the projector's output size (LM hidden size).
    return abstractor_cls(config, num_tokens, lm_hidden_size)
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the module that maps vision features to the LM hidden size.

    Supported ``config.mm_projector_type`` values (default ``'linear'``):

    * ``'linear'``: one ``nn.Linear(mm_hidden_size, hidden_size)``.
    * ``'c-abs'``: abstractor configured from ``config.mm_projector_config``,
      using ``config.mm_num_tokens`` and ``config.hidden_size``.
    * ``'mlp{N}x_gelu'``: ``N`` linear layers separated by GELU.
    * ``'identity'``: pass-through.

    Args:
        config: Model config carrying the attributes listed above.
        delay_load: Accepted for interface compatibility; unused here.
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If the projector type is not recognised.
    """
    # Local import: `re` is needed below but is not provided by the module's
    # import block, which made every mlp/identity/unknown type fail with a
    # NameError.
    import re

    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == 'c-abs':
        local_config_path = config.mm_projector_config
        honeybee_config = HoneybeeVisualProjectorConfig.from_pretrained(local_config_path)
        num_tokens = config.mm_num_tokens
        lm_hidden_size = config.hidden_size
        return build_honeybee_projector(honeybee_config, projector_type, num_tokens, lm_hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
class QH360_VL_MetaModel:
    """Mixin that attaches the vision tower and the two projectors to a base model.

    Meant to precede the real model class in the MRO, since ``__init__``
    forwards ``config`` to the next class.
    """

    def __init__(self, config):
        super().__init__(config)

        # Vision parts are only built for configs that name a vision tower.
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            # Two independent projectors built from the same config.
            self.mm_projector_ctt = build_vision_projector(config)
            self.mm_projector_ori = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower (first element if stored as a list), or ``None``."""
        tower = getattr(self, 'vision_tower', None)
        return tower[0] if type(tower) is list else tower
class QH360_VL_MetaForCausalLM(ABC):
    """Mixin that turns token ids plus images into embeddings for the language model.

    Concrete classes must implement ``get_model`` and provide ``self.config``
    and ``self.device`` (both are read below).
    """

    @abstractmethod
    def get_model(self):
        """Return the wrapped base model holding the vision tower, projectors and ``embed_tokens``."""
        pass

    def get_vision_tower(self):
        """Return the base model's vision tower (may be ``None``)."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Encode images with the vision tower and project them with ``mm_projector``.

        NOTE(review): ``QH360_VL_MetaModel`` in this file only creates
        ``mm_projector_ctt`` and ``mm_projector_ori``; confirm that
        ``mm_projector`` exists on the model before relying on this path.
        """
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_images_noprojector(self, images):
        """Encode images with the vision tower only and detach the result."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = image_features.detach()
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Replace every ``IMAGE_TOKEN_INDEX`` in ``input_ids`` with image features.

        Returns:
            ``(input_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
            When there is no vision tower, no images, or a single input token,
            ``input_ids`` is passed through and ``inputs_embeds`` is ``None``;
            otherwise ``input_ids`` is ``None`` and ``inputs_embeds`` holds the
            spliced (and, if needed, right-padded) embeddings, with labels and
            attention mask widened to match.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            # Single-token step with a cache: rebuild an all-ones mask one
            # position longer than the cached sequence length.
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            image_features = []
            for image in images:
                if image.ndim == 3:
                    image_features.append(self.encode_images(image.unsqueeze(0)).squeeze(0))
                elif image.ndim == 4:
                    #NOTE cc-plan
                    temp_feats = self.encode_images_noprojector(image)
                    # Side length of the (assumed square) token grid per view.
                    src_size = int(math.sqrt(temp_feats.shape[1]))
                    # Group the views in fives: indices 0-3 and index 4 are
                    # handled separately below.
                    # NOTE(review): presumably four crops plus one global view,
                    # matching process_images_slid_window — confirm.
                    temp_feats = temp_feats.reshape(temp_feats.shape[0]//5,5,-1, temp_feats.shape[-1])
                    x1 = temp_feats[:,4,:,:]
                    x = temp_feats[:,:4,:,:]
                    # Rearrange the four views' token grids into one token
                    # sequence (looks like a 2x2 tiling — TODO confirm layout).
                    x = x.reshape(x.shape[0], -1, src_size, src_size, x.shape[-1])
                    x = x.transpose(1,2).reshape(x.shape[0], src_size,2,2, src_size, x.shape[-1])
                    x = x.transpose(1,2).reshape(x.shape[0], -1, x.shape[-1])
                    x1 = self.get_model().mm_projector_ori(x1).squeeze(0)
                    x = self.get_model().mm_projector_ctt(x).squeeze(0)
                    # Features from the first four views come first, then the fifth view's.
                    temp_feats_all = torch.cat([x,x1],dim=0)
                    image_features.append(temp_feats_all)
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                # The empty slice [0:0] adds no rows; it only references the image features.
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            # Consume the sequence one image placeholder at a time; after each
            # splice cur_input_ids / cur_labels are cut to the remaining tail.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    # Text before the token preceding the placeholder is detached;
                    # the tokens directly around the placeholder are not.
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        # Image positions never contribute to the loss.
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        # Image positions never contribute to the loss.
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            # Remaining text after the last image placeholder.
            if cur_input_ids.numel() > 0:
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Samples ended up with different lengths: right-pad embeddings with
            # zeros and labels with IGNORE_INDEX up to the longest sample.
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            # NOTE(review): indentation was lost in this copy of the file; this
            # block is placed as a sibling of the labels block above. It reads
            # `_new_labels`, which is only bound when labels is not None —
            # confirm the intended nesting and the labels-free case.
            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    # Left pad covers the positions added by image features;
                    # right pad covers the alignment padding.
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                # Widen the mask on the left by the number of positions gained
                # from replacing placeholders with image features.
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels
class QH360_VLConfig(LlamaConfig):
    """Llama configuration registered under the ``QH_360VL`` model type."""

    # Same string HoneybeeVisualProjectorConfig.from_pretrained checks for
    # when unwrapping a nested projector config.
    model_type = "QH_360VL"
class QH360_VL_LlamaModel(QH360_VL_MetaModel, LlamaModel):
    """Llama base model extended with the vision parts from ``QH360_VL_MetaModel``."""

    config_class = QH360_VLConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init: QH360_VL_MetaModel runs first in the MRO.
        super().__init__(config)
class QH360_VL_LlamaForCausalLM(LlamaForCausalLM, QH360_VL_MetaForCausalLM):
    """Causal language model that accepts images alongside token ids.

    Image placeholders (``IMAGE_TOKEN_INDEX``) in ``input_ids`` are replaced by
    projected vision features before the Llama decoder runs.
    """

    config_class = QH360_VLConfig

    def __init__(self, config):
        # Skip LlamaForCausalLM.__init__ so the multimodal base model is
        # installed as `self.model` instead of a plain LlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        # NOTE: a stray no-op comparison
        # (`config._attn_implementation == "flash_attention_2"`) used to sit
        # here; the attention implementation is whatever `config` specifies.
        self.model = QH360_VL_LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal base model."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder on text plus images and optionally compute the LM loss.

        The ``inputs_embeds`` argument is overwritten by the result of
        ``prepare_inputs_labels_for_multimodal``; pass token ids and ``images``.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *decoder_outputs[1:])``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the per-step model inputs used during generation, including ``images``."""
        if past_key_values:
            # With a cache only the newest token needs to be fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def build_conversation_input_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        image = None,
        image_processor=None,
    ):
        """Build the chat prompt ids and the image tensor for one PHI query.

        The prompt is a fixed system message plus a user message made of an
        image placeholder token, a newline and ``query``. After templating, the
        first occurrence of the placeholder id is replaced by
        ``IMAGE_TOKEN_INDEX``.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template``.
            query: User question about the image.
            image: PIL image to analyse (required).
            image_processor: CLIP image processor (required).

        Returns:
            Dict with ``'input_ids'`` of shape (1, seq_len) and ``'image'`` of
            shape (1, 5, C, H, W).

        Raises:
            ValueError: If ``image`` or ``image_processor`` is missing, or the
                templated prompt does not contain the placeholder id.
        """
        if image is None or image_processor is None:
            raise ValueError(
                "build_conversation_input_ids requires both `image` and `image_processor`."
            )

        placeholder_token = "<|reserved_special_token_44|>"
        # Id the placeholder is expected to be tokenized to.
        # NOTE(review): hard-coded for the tokenizer this model ships with.
        placeholder_id = 128049

        sysp = (
            'You are an expert identifying and clasifying information as Protected Health Information (PHI) based on the following criteria: '
            'Inferred PHI encompasses information about a specific disease/condition/medical diagnosis, in the context '
            'that the information MUST directly identify the specific disease/condition/medical diagnosis and MUST provide '
            'treatments/services/specific information about the specific disease/condition/medical diagnosis; you MUST identify an actual disease/condition/diagnosis in the information to be considered PHI. '
            'If you are unable to definitively determine the presence of PHI based information in the image, then you DID NOT identify PHI and your determination/response should begin with "No". '
            'When providing your response it MUST start with "Yes" or "No" based on your review of the image, followed by a brief summary explanation of the rationale for the "Yes" or "No" decision including the information you found supporting that rational; '
            'you MUST INCLUDE a brief summary explanation of the rationale for the "Yes" or "No" decision including the information you found supporting that rational.'
        )

        input_msg = [
            {
                "role": "system",
                "content": sysp
            },
            {
                "role": "user",
                "content": placeholder_token + '\n' + query
            }
        ]

        input_ids = tokenizer.apply_chat_template(
            input_msg,
            add_generation_prompt=True,
            padding="longest",
            return_tensors="pt",
        )

        input_id_list = input_ids[0].tolist()
        try:
            placeholder_pos = input_id_list.index(placeholder_id)
        except ValueError as err:
            raise ValueError(
                f"Image placeholder id {placeholder_id} not found in the templated prompt."
            ) from err
        input_id_list[placeholder_pos] = IMAGE_TOKEN_INDEX
        input_ids = torch.tensor(input_id_list, dtype=input_ids.dtype, device=input_ids.device)
        input_ids = input_ids.unsqueeze(0)

        image_tensor = self.process_images_slid_window(image, image_processor).unsqueeze(0)

        return {
            'input_ids': input_ids,
            'image': image_tensor,
        }

    def process_images_slid_window(self, image, image_processor, vit_is=336):
        """Turn one PIL image into five stacked views for the vision tower.

        The image is padded to a square with the processor's mean colour and
        resized to ``2 * vit_is`` per side. Four corner crops are cut from that
        version and stacked with one view produced by the processor's default
        preprocessing.

        Args:
            image: PIL image.
            image_processor: Processor providing ``image_mean`` and ``preprocess``.
            vit_is: Vision-encoder input size; controls the resize target only.
                The crop size is taken from the default-preprocessed view.

        Returns:
            Tensor stacking ``[top-left, top-right, bottom-left, bottom-right, full]``.
        """
        def get_proper_imgsize(pil_img, vit_is):
            # Resize to twice the encoder input size on both sides.
            max_w_h = vit_is * 2
            return pil_img.resize((max_w_h, max_w_h))

        def tensor_crop(tensor_array, left, upper, right, lower):
            # tensor_array: C * H * W
            return tensor_array[:, upper:lower, left:right]

        def image_slid_window(image, num_slid_window):
            # image: tensor, 3 * 336 * 336 or 3 * 672 * 672
            # image: tensor, 3 * 224 * 224 or 3 * 448 * 448
            if num_slid_window != 5:
                return image
            image_x2, image_x1 = image[0], image[1]
            crop = image_x1.shape[1]
            h, w = image_x2.shape[1], image_x2.shape[2]
            corners = [
                tensor_crop(image_x2, 0, 0, crop, crop),
                tensor_crop(image_x2, w - crop, 0, w, crop),
                tensor_crop(image_x2, 0, h - crop, crop, h),
                tensor_crop(image_x2, w - crop, h - crop, w, h),
            ]
            return torch.stack(corners + [image_x1])

        def expand2square(pil_img, background_color):
            # Pad the shorter side symmetrically with the background colour.
            width, height = pil_img.size
            if width == height:
                return pil_img
            side = max(width, height)
            result = Image.new(pil_img.mode, (side, side), background_color)
            if width > height:
                result.paste(pil_img, (0, (width - height) // 2))
            else:
                result.paste(pil_img, ((height - width) // 2, 0))
            return result

        num_slid_window = 5

        background = tuple(int(x * 255) for x in image_processor.image_mean)
        image = expand2square(image, background)
        image = get_proper_imgsize(image, vit_is)
        # High-resolution view: normalised only, no resize / centre crop.
        image_x2 = image_processor.preprocess(image, return_tensors='pt', do_resize=False, do_center_crop=False)['pixel_values'][0]
        # Low-resolution view: the processor's default pipeline.
        image_x1 = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        return image_slid_window([image_x2, image_x1], num_slid_window)