Update modeling_qwen_yarn.py
Browse files
- modeling_qwen_yarn.py +4 -0
modeling_qwen_yarn.py
CHANGED
@@ -1156,6 +1156,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
1156 |
output = (lm_logits,) + transformer_outputs[1:]
|
1157 |
return ((loss,) + output) if loss is not None else output
|
1158 |
|
|
|
|
|
|
|
|
|
1159 |
return CausalLMOutputWithPast(
|
1160 |
loss=loss,
|
1161 |
logits=lm_logits,
|
|
|
1156 |
output = (lm_logits,) + transformer_outputs[1:]
|
1157 |
return ((loss,) + output) if loss is not None else output
|
1158 |
|
1159 |
+
#训练时节约显存
|
1160 |
+
# if self.training:
|
1161 |
+
# lm_logits=None
|
1162 |
+
|
1163 |
return CausalLMOutputWithPast(
|
1164 |
loss=loss,
|
1165 |
logits=lm_logits,
|