myownskyW7
commited on
Commit
•
8a8a3ae
1
Parent(s):
e782b1f
Add fine-tuning code
Browse files- modeling_InternLM.py +96 -15
- modeling_InternLM_XComposer.py +149 -7
- modeling_utils.py +29 -22
modeling_InternLM.py
CHANGED
@@ -16,6 +16,11 @@ from transformers.utils import logging
|
|
16 |
from .configuration_InternLM_XComposer import InternLMXComposerConfig
|
17 |
from .modeling_utils import LoRALinear
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
logger = logging.get_logger(__name__)
|
20 |
|
21 |
_CONFIG_FOR_DOC = "InternLMXComposerConfig"
|
@@ -31,7 +36,6 @@ def rotary_embed(x1, x2, cos, sin, conj):
|
|
31 |
|
32 |
|
33 |
class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
|
34 |
-
|
35 |
@staticmethod
|
36 |
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
|
37 |
"""
|
@@ -51,18 +55,26 @@ class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
51 |
assert seqlen <= rotary_seqlen
|
52 |
cos_k = cos if cos_k is None else cos_k
|
53 |
sin_k = sin if sin_k is None else sin_k
|
54 |
-
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
|
|
|
55 |
q_ro = qkv[:, :, 0, :, :rotary_dim]
|
56 |
-
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2],
|
|
|
|
|
57 |
# rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
58 |
# rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
59 |
-
q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
|
|
60 |
qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
|
61 |
k_ro = qkv[:, :, 1, :, :rotary_dim]
|
62 |
-
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2],
|
|
|
|
|
63 |
# rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
64 |
# rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
65 |
-
k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen],
|
|
|
|
|
66 |
qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
|
67 |
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
68 |
ctx.interleaved = interleaved
|
@@ -75,18 +87,69 @@ class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
75 |
rotary_dim = cos.shape[-1]
|
76 |
rotary_dim *= 2
|
77 |
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
|
78 |
-
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
|
79 |
-
|
80 |
-
rotary_emb.apply_rotary(dq1, dq2,
|
81 |
-
rearrange(
|
|
|
|
|
82 |
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
|
83 |
-
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
|
84 |
-
|
85 |
-
rotary_emb.apply_rotary(dk1, dk2,
|
86 |
-
rearrange(
|
|
|
|
|
87 |
return dqkv, None, None, None, None, None
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
|
91 |
def __init__(self, dim: int, base=10000, scale_base=0, device=None):
|
92 |
""" """
|
@@ -137,6 +200,23 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
|
|
137 |
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
138 |
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
def eval_forward(self, qkv, seqlen_offset=0):
|
141 |
"""
|
142 |
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
@@ -157,6 +237,7 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
|
|
157 |
)
|
158 |
|
159 |
|
|
|
160 |
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
161 |
|
162 |
|
@@ -1241,6 +1322,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
1241 |
reordered_past = ()
|
1242 |
for layer_past in past_key_values:
|
1243 |
reordered_past += (tuple(
|
1244 |
-
past_state.index_select(0, beam_idx)
|
1245 |
for past_state in layer_past), )
|
1246 |
return reordered_past
|
|
|
16 |
from .configuration_InternLM_XComposer import InternLMXComposerConfig
|
17 |
from .modeling_utils import LoRALinear
|
18 |
|
19 |
+
try:
|
20 |
+
import rotary_emb
|
21 |
+
except Exception as e:
|
22 |
+
print('Please following docs/install.md to install rotary_emb if you want to do fine-tuning')
|
23 |
+
|
24 |
logger = logging.get_logger(__name__)
|
25 |
|
26 |
_CONFIG_FOR_DOC = "InternLMXComposerConfig"
|
|
|
36 |
|
37 |
|
38 |
class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
|
39 |
@staticmethod
|
40 |
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
|
41 |
"""
|
|
|
55 |
assert seqlen <= rotary_seqlen
|
56 |
cos_k = cos if cos_k is None else cos_k
|
57 |
sin_k = sin if sin_k is None else sin_k
|
58 |
+
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
|
59 |
+
rotary_dim // 2)
|
60 |
q_ro = qkv[:, :, 0, :, :rotary_dim]
|
61 |
+
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2],
|
62 |
+
q_ro[...,
|
63 |
+
1::2])
|
64 |
# rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
65 |
# rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
66 |
+
q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
67 |
+
rearrange(sin[:seqlen], 's d -> s 1 d'), False)
|
68 |
qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
|
69 |
k_ro = qkv[:, :, 1, :, :rotary_dim]
|
70 |
+
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2],
|
71 |
+
k_ro[...,
|
72 |
+
1::2])
|
73 |
# rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
74 |
# rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
75 |
+
k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen],
|
76 |
+
's d -> s 1 d'),
|
77 |
+
rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
|
78 |
qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
|
79 |
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
80 |
ctx.interleaved = interleaved
|
|
|
87 |
rotary_dim = cos.shape[-1]
|
88 |
rotary_dim *= 2
|
89 |
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
|
90 |
+
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved else
|
91 |
+
(dq_ro[..., ::2], dq_ro[..., 1::2]))
|
92 |
+
rotary_emb.apply_rotary(dq1, dq2,
|
93 |
+
rearrange(cos[:seqlen], 's d -> s 1 d'),
|
94 |
+
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1,
|
95 |
+
dq2, True)
|
96 |
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
|
97 |
+
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved else
|
98 |
+
(dk_ro[..., ::2], dk_ro[..., 1::2]))
|
99 |
+
rotary_emb.apply_rotary(dk1, dk2,
|
100 |
+
rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
101 |
+
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1,
|
102 |
+
dk2, True)
|
103 |
return dqkv, None, None, None, None, None
|
104 |
|
105 |
|
106 |
+
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
107 |
+
"""
|
108 |
+
ApplyRotaryEmbQKV_
|
109 |
+
"""
|
110 |
+
@staticmethod
|
111 |
+
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
|
112 |
+
"""
|
113 |
+
qkv: (total, 3, nheads, headdim)
|
114 |
+
cos, sin: (seqlen, rotary_dim / 2)
|
115 |
+
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
116 |
+
rotary_dim must be <= headdim
|
117 |
+
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
|
118 |
+
"""
|
119 |
+
_, three, _, headdim = qkv.shape
|
120 |
+
assert three == 3
|
121 |
+
rotary_seqlen, rotary_dim = cos.shape
|
122 |
+
rotary_dim *= 2
|
123 |
+
assert rotary_dim <= headdim
|
124 |
+
cos_k = cos if cos_k is None else cos_k
|
125 |
+
sin_k = sin if sin_k is None else sin_k
|
126 |
+
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
|
127 |
+
rotary_dim // 2)
|
128 |
+
q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
129 |
+
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"),
|
130 |
+
rearrange(sin, "s d -> s 1 d"), q1, q2, False)
|
131 |
+
k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
132 |
+
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k, "s d -> s 1 d"),
|
133 |
+
rearrange(sin_k, "s d -> s 1 d"), k1, k2,
|
134 |
+
False)
|
135 |
+
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
136 |
+
return qkv
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def backward(ctx, dqkv):
|
140 |
+
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
141 |
+
rotary_dim = cos.shape[-1]
|
142 |
+
rotary_dim *= 2
|
143 |
+
dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
144 |
+
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos, "s d -> s 1 d"),
|
145 |
+
rearrange(sin, "s d -> s 1 d"), dq1, dq2, True)
|
146 |
+
dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
147 |
+
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k, "s d -> s 1 d"),
|
148 |
+
rearrange(sin_k, "s d -> s 1 d"), dk1, dk2,
|
149 |
+
True)
|
150 |
+
return dqkv, None, None, None, None
|
151 |
+
|
152 |
+
|
153 |
class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
|
154 |
def __init__(self, dim: int, base=10000, scale_base=0, device=None):
|
155 |
""" """
|
|
|
200 |
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
201 |
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
202 |
|
203 |
+
def forward(self,
|
204 |
+
qkv: torch.Tensor,
|
205 |
+
indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
206 |
+
self._update_cos_sin_cache(qkv, indexes)
|
207 |
+
if self.scale is None:
|
208 |
+
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
|
209 |
+
self._sin_cached[indexes]).to(
|
210 |
+
qkv.dtype)
|
211 |
+
else:
|
212 |
+
return apply_rotary_emb_qkv_(
|
213 |
+
qkv,
|
214 |
+
self._cos_cached[indexes],
|
215 |
+
self._sin_cached[indexes],
|
216 |
+
self._cos_k_cached[indexes],
|
217 |
+
self._sin_k_cached[indexes],
|
218 |
+
).to(qkv.dtype)
|
219 |
+
|
220 |
def eval_forward(self, qkv, seqlen_offset=0):
|
221 |
"""
|
222 |
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
|
|
237 |
)
|
238 |
|
239 |
|
240 |
+
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
241 |
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
242 |
|
243 |
|
|
|
1322 |
reordered_past = ()
|
1323 |
for layer_past in past_key_values:
|
1324 |
reordered_past += (tuple(
|
1325 |
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
1326 |
for past_state in layer_past), )
|
1327 |
return reordered_past
|
modeling_InternLM_XComposer.py
CHANGED
@@ -46,12 +46,13 @@ conversation
|
|
46 |
def __init__(self, config):
|
47 |
super().__init__(config)
|
48 |
|
49 |
-
|
|
|
50 |
self.visual_encoder = create_eva_vit_g()
|
51 |
self.ln_vision = LayerNorm(self.visual_encoder.num_features)
|
52 |
-
|
53 |
|
54 |
-
|
55 |
with all_logging_disabled():
|
56 |
self.Qformer, self.query_tokens = self.init_qformer(
|
57 |
config.num_query_token, self.visual_encoder.num_features)
|
@@ -61,9 +62,9 @@ conversation
|
|
61 |
layer.output = None
|
62 |
layer.intermediate = None
|
63 |
self.Qformer.cls = None
|
64 |
-
|
65 |
|
66 |
-
|
67 |
self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
|
68 |
self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
|
69 |
self.flag_image_start.requires_grad = False
|
@@ -81,7 +82,7 @@ conversation
|
|
81 |
# speed up init llm
|
82 |
with torch.device('meta'):
|
83 |
self.internlm_model = InternLMForCausalLM._from_config(config)
|
84 |
-
self.internlm_model.to_empty(device=
|
85 |
self.internlm_model.to(config.device)
|
86 |
for n, m in self.internlm_model.named_modules():
|
87 |
if 'lora' in n:
|
@@ -89,7 +90,7 @@ conversation
|
|
89 |
|
90 |
self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
|
91 |
self.internlm_model.config.hidden_size)
|
92 |
-
|
93 |
|
94 |
self.vis_processor = transforms.Compose([
|
95 |
transforms.Resize((224, 224),
|
@@ -111,6 +112,17 @@ conversation
|
|
111 |
[StoppingCriteriaSub(stops=stop_words_ids)])
|
112 |
self.gen_config['stopping_criteria'] = stopping_criteria
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
def maybe_autocast(self, dtype=torch.float16):
|
115 |
# if on cpu, don't use autocast
|
116 |
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
|
@@ -268,3 +280,133 @@ conversation
|
|
268 |
if history is not None:
|
269 |
prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
|
270 |
return prompt_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def __init__(self, config):
|
47 |
super().__init__(config)
|
48 |
|
49 |
+
self.max_length = config.max_length
|
50 |
+
rank0_print('Init VIT ... ', end='')
|
51 |
self.visual_encoder = create_eva_vit_g()
|
52 |
self.ln_vision = LayerNorm(self.visual_encoder.num_features)
|
53 |
+
rank0_print('Done')
|
54 |
|
55 |
+
rank0_print('Init Perceive Sampler ... ', end='')
|
56 |
with all_logging_disabled():
|
57 |
self.Qformer, self.query_tokens = self.init_qformer(
|
58 |
config.num_query_token, self.visual_encoder.num_features)
|
|
|
62 |
layer.output = None
|
63 |
layer.intermediate = None
|
64 |
self.Qformer.cls = None
|
65 |
+
rank0_print('Done')
|
66 |
|
67 |
+
rank0_print('Init InternLM ... ', end='')
|
68 |
self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
|
69 |
self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
|
70 |
self.flag_image_start.requires_grad = False
|
|
|
82 |
# speed up init llm
|
83 |
with torch.device('meta'):
|
84 |
self.internlm_model = InternLMForCausalLM._from_config(config)
|
85 |
+
self.internlm_model.to_empty(device=config.device).to(torch.float16)
|
86 |
self.internlm_model.to(config.device)
|
87 |
for n, m in self.internlm_model.named_modules():
|
88 |
if 'lora' in n:
|
|
|
90 |
|
91 |
self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
|
92 |
self.internlm_model.config.hidden_size)
|
93 |
+
rank0_print('Done')
|
94 |
|
95 |
self.vis_processor = transforms.Compose([
|
96 |
transforms.Resize((224, 224),
|
|
|
112 |
[StoppingCriteriaSub(stops=stop_words_ids)])
|
113 |
self.gen_config['stopping_criteria'] = stopping_criteria
|
114 |
|
115 |
+
self.supports_gradient_checkpointing = True
|
116 |
+
|
117 |
+
def get_input_embeddings(self):
|
118 |
+
return self.internlm_model.get_input_embeddings()
|
119 |
+
|
120 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
121 |
+
if value:
|
122 |
+
self.internlm_model.apply(
|
123 |
+
partial(self.internlm_model._set_gradient_checkpointing,
|
124 |
+
value=True))
|
125 |
+
|
126 |
def maybe_autocast(self, dtype=torch.float16):
|
127 |
# if on cpu, don't use autocast
|
128 |
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
|
|
|
280 |
if history is not None:
|
281 |
prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
|
282 |
return prompt_embeds
|
283 |
+
|
284 |
+
######################
|
285 |
+
# code for training
|
286 |
+
######################
|
287 |
+
def prompt_wrap(self, img_embeds, prompt):
|
288 |
+
batch_size = img_embeds.shape[0]
|
289 |
+
p_before, p_after = prompt.split('<ImageHere>')
|
290 |
+
p_before_tokens = self.tokenizer(p_before,
|
291 |
+
return_tensors="pt",
|
292 |
+
add_special_tokens=True).to(
|
293 |
+
img_embeds.device)
|
294 |
+
|
295 |
+
p_before_embeds = self.internlm_model.model.embed_tokens(
|
296 |
+
p_before_tokens.input_ids).expand(batch_size, -1, -1)
|
297 |
+
wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds], dim=1)
|
298 |
+
|
299 |
+
wrapped_atts_img = torch.ones(wrapped_img_embeds.size()[:-1],
|
300 |
+
dtype=torch.long).to(img_embeds.device)
|
301 |
+
|
302 |
+
wrapped_target = torch.ones(
|
303 |
+
batch_size, wrapped_img_embeds.shape[1], dtype=torch.long).to(
|
304 |
+
img_embeds.device) * -100
|
305 |
+
|
306 |
+
return wrapped_img_embeds, wrapped_atts_img, wrapped_target
|
307 |
+
|
308 |
+
def align_text(self, samples, has_img=False): ### add eos and eoa
|
309 |
+
text_new = []
|
310 |
+
if has_img: ### remove the first user to wrap image features
|
311 |
+
text = [
|
312 |
+
t.replace("<image>", "").split("<|User|>:", 1)[-1].lstrip()
|
313 |
+
for t in samples["text_input"]
|
314 |
+
]
|
315 |
+
else:
|
316 |
+
text = [t for t in samples["text_input"]]
|
317 |
+
|
318 |
+
text = [t + self.eoa + ' </s>' for t in text]
|
319 |
+
for i in range(len(text)):
|
320 |
+
temp = text[i]
|
321 |
+
temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
|
322 |
+
temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
|
323 |
+
if temp.find(self.eoh) > temp.find(self.eoa):
|
324 |
+
temp = temp.replace(self.eoa, '', 1)
|
325 |
+
text_new.append(temp)
|
326 |
+
return text_new
|
327 |
+
|
328 |
+
def text2emb(self, text):
|
329 |
+
to_regress_tokens = self.tokenizer(text,
|
330 |
+
return_tensors="pt",
|
331 |
+
padding="longest",
|
332 |
+
truncation=True,
|
333 |
+
max_length=self.max_length,
|
334 |
+
add_special_tokens=False).to(
|
335 |
+
self.device)
|
336 |
+
|
337 |
+
targets = self.mask_human_targets(to_regress_tokens.input_ids)
|
338 |
+
targets = targets.to(self.device)
|
339 |
+
|
340 |
+
return to_regress_tokens, targets
|
341 |
+
|
342 |
+
def mask_human_targets(self, input_ids, pure=False):
|
343 |
+
target_batch = []
|
344 |
+
for bs in range(input_ids.shape[0]):
|
345 |
+
cur_idx = 0
|
346 |
+
ids = input_ids[bs]
|
347 |
+
targets = copy.deepcopy(ids)
|
348 |
+
last_eoa = 0
|
349 |
+
last_eoh = 0
|
350 |
+
for i, temp_id in enumerate(ids):
|
351 |
+
if temp_id == 103027: #### end of human
|
352 |
+
targets[cur_idx:i + 6] = -100
|
353 |
+
cur_idx = i + 6
|
354 |
+
last_eoh = i
|
355 |
+
elif temp_id == 103028: ### end of assistant
|
356 |
+
cur_idx = i + 1
|
357 |
+
last_eoa = i
|
358 |
+
elif temp_id == 2: ### eos and following pad
|
359 |
+
targets[i + 1:] = -100 #### loss on eos, but not on pad
|
360 |
+
break
|
361 |
+
if temp_id != 2 and last_eoa > last_eoh: ### trunction, end at last question
|
362 |
+
targets[last_eoa +
|
363 |
+
1:] = -100 #### mask all after the last answer
|
364 |
+
|
365 |
+
target_batch.append(targets.unsqueeze(0))
|
366 |
+
|
367 |
+
target_batch = torch.cat(target_batch, dim=0)
|
368 |
+
return target_batch
|
369 |
+
|
370 |
+
def forward(self,
|
371 |
+
input_ids=None,
|
372 |
+
attention_mask=None,
|
373 |
+
inputs_embeds=None,
|
374 |
+
labels=None,
|
375 |
+
output_attentions=None,
|
376 |
+
output_hidden_states=None,
|
377 |
+
return_dict=None,
|
378 |
+
**kwargs):
|
379 |
+
|
380 |
+
samples = kwargs.get('samples')
|
381 |
+
has_img = 'images' in samples.keys()
|
382 |
+
|
383 |
+
### encode text
|
384 |
+
text = self.align_text(samples, has_img=has_img)
|
385 |
+
to_regress_tokens, targets = self.text2emb(text)
|
386 |
+
|
387 |
+
to_regress_embeds = self.internlm_model.model.embed_tokens(
|
388 |
+
to_regress_tokens.input_ids)
|
389 |
+
attention_mask = to_regress_tokens.attention_mask
|
390 |
+
|
391 |
+
if has_img:
|
392 |
+
header = samples["text_input"][0].split(' <|User|>:')[0]
|
393 |
+
prompt = header + ' <|User|>:<ImageHere>'
|
394 |
+
|
395 |
+
### encode image
|
396 |
+
image = samples["image"]
|
397 |
+
img_embeds = self.encode_img(image)
|
398 |
+
img_embeds, atts_img, wrapped_target = self.prompt_wrap(
|
399 |
+
img_embeds, prompt)
|
400 |
+
### combine text and image
|
401 |
+
to_regress_embeds = torch.cat([img_embeds, to_regress_embeds],
|
402 |
+
dim=1)
|
403 |
+
attention_mask = torch.cat([atts_img, attention_mask], dim=1)
|
404 |
+
targets = torch.cat([wrapped_target, targets], dim=1)
|
405 |
+
|
406 |
+
outputs = self.internlm_model(
|
407 |
+
inputs_embeds=to_regress_embeds,
|
408 |
+
attention_mask=attention_mask,
|
409 |
+
return_dict=True,
|
410 |
+
labels=targets,
|
411 |
+
)
|
412 |
+
return outputs
|
modeling_utils.py
CHANGED
@@ -2,12 +2,12 @@ import logging
|
|
2 |
import math
|
3 |
import os
|
4 |
from contextlib import contextmanager
|
5 |
-
from transformers import StoppingCriteria, StoppingCriteriaList
|
6 |
|
7 |
import timm.models.hub as timm_hub
|
8 |
import torch
|
9 |
import torch.distributed as dist
|
10 |
import torch.nn as nn
|
|
|
11 |
|
12 |
|
13 |
def is_dist_avail_and_initialized():
|
@@ -28,12 +28,16 @@ def is_main_process():
|
|
28 |
return get_rank() == 0
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
def download_cached_file(url, check_hash=True, progress=False):
|
32 |
"""
|
33 |
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
34 |
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
35 |
"""
|
36 |
-
|
37 |
def get_cached_file_path():
|
38 |
# a hack to sync the file path across processes
|
39 |
parts = torch.hub.urlparse(url)
|
@@ -76,18 +80,16 @@ def all_logging_disabled(highest_level=logging.CRITICAL):
|
|
76 |
|
77 |
|
78 |
class LoRALinear(nn.Linear):
|
79 |
-
def __init__(
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
**kwargs
|
90 |
-
) -> None:
|
91 |
super().__init__(in_features, out_features, bias, device, dtype)
|
92 |
self.lora_r = lora_r
|
93 |
self.lora_alpha = lora_alpha
|
@@ -97,12 +99,16 @@ class LoRALinear(nn.Linear):
|
|
97 |
self.lora_dropout = lambda x: x
|
98 |
self.lora_scaling = self.lora_alpha / self.lora_r
|
99 |
|
100 |
-
self.lora_A = nn.Linear(
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
106 |
|
107 |
self.reset_parameters()
|
108 |
|
@@ -116,7 +122,8 @@ class LoRALinear(nn.Linear):
|
|
116 |
orig_type = x.dtype
|
117 |
res = super().forward(x)
|
118 |
x = x.float()
|
119 |
-
res += self.lora_B(self.lora_A(
|
|
|
120 |
return res.to(orig_type)
|
121 |
|
122 |
|
@@ -127,7 +134,7 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
127 |
|
128 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
129 |
for stop in self.stops:
|
130 |
-
if torch.all((stop == input_ids[:, -len(stop)
|
131 |
return True
|
132 |
|
133 |
return False
|
|
|
2 |
import math
|
3 |
import os
|
4 |
from contextlib import contextmanager
|
|
|
5 |
|
6 |
import timm.models.hub as timm_hub
|
7 |
import torch
|
8 |
import torch.distributed as dist
|
9 |
import torch.nn as nn
|
10 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
11 |
|
12 |
|
13 |
def is_dist_avail_and_initialized():
|
|
|
28 |
return get_rank() == 0
|
29 |
|
30 |
|
31 |
+
def rank0_print(msg, **kwargs):
|
32 |
+
if is_main_process():
|
33 |
+
print(msg, **kwargs)
|
34 |
+
|
35 |
+
|
36 |
def download_cached_file(url, check_hash=True, progress=False):
|
37 |
"""
|
38 |
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
39 |
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
40 |
"""
|
|
|
41 |
def get_cached_file_path():
|
42 |
# a hack to sync the file path across processes
|
43 |
parts = torch.hub.urlparse(url)
|
|
|
80 |
|
81 |
|
82 |
class LoRALinear(nn.Linear):
|
83 |
+
def __init__(self,
|
84 |
+
in_features: int,
|
85 |
+
out_features: int,
|
86 |
+
bias: bool = True,
|
87 |
+
device=None,
|
88 |
+
dtype=None,
|
89 |
+
lora_r=8,
|
90 |
+
lora_alpha=16,
|
91 |
+
lora_dropout=0.05,
|
92 |
+
**kwargs) -> None:
|
|
|
|
|
93 |
super().__init__(in_features, out_features, bias, device, dtype)
|
94 |
self.lora_r = lora_r
|
95 |
self.lora_alpha = lora_alpha
|
|
|
99 |
self.lora_dropout = lambda x: x
|
100 |
self.lora_scaling = self.lora_alpha / self.lora_r
|
101 |
|
102 |
+
self.lora_A = nn.Linear(in_features,
|
103 |
+
self.lora_r,
|
104 |
+
bias=False,
|
105 |
+
device=device,
|
106 |
+
dtype=dtype)
|
107 |
+
self.lora_B = nn.Linear(self.lora_r,
|
108 |
+
out_features,
|
109 |
+
bias=False,
|
110 |
+
device=device,
|
111 |
+
dtype=dtype)
|
112 |
|
113 |
self.reset_parameters()
|
114 |
|
|
|
122 |
orig_type = x.dtype
|
123 |
res = super().forward(x)
|
124 |
x = x.float()
|
125 |
+
res += self.lora_B(self.lora_A(
|
126 |
+
self.lora_dropout(x))) * self.lora_scaling
|
127 |
return res.to(orig_type)
|
128 |
|
129 |
|
|
|
134 |
|
135 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
136 |
for stop in self.stops:
|
137 |
+
if torch.all((stop == input_ids[:, -len(stop):])).item():
|
138 |
return True
|
139 |
|
140 |
return False
|