jupyterjazz
commited on
Commit
•
d9d8306
1
Parent(s):
4e13c90
fix: residual is kept in kwargs
Browse filesSigned-off-by: jupyterjazz <[email protected]>
mha.py
CHANGED
@@ -649,7 +649,8 @@ class MHA(nn.Module):
|
|
649 |
if not self.return_residual:
|
650 |
qkv = self.Wqkv(x, **lora_kwargs)
|
651 |
else:
|
652 |
-
lora_kwargs
|
|
|
653 |
qkv, x = self.Wqkv(x, **lora_kwargs)
|
654 |
|
655 |
if self.dwconv:
|
@@ -737,5 +738,6 @@ class MHA(nn.Module):
|
|
737 |
else:
|
738 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
739 |
|
|
|
740 |
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
|
741 |
return out if not self.return_residual else (out, x)
|
|
|
649 |
if not self.return_residual:
|
650 |
qkv = self.Wqkv(x, **lora_kwargs)
|
651 |
else:
|
652 |
+
if lora_kwargs:
|
653 |
+
lora_kwargs['residual'] = True
|
654 |
qkv, x = self.Wqkv(x, **lora_kwargs)
|
655 |
|
656 |
if self.dwconv:
|
|
|
738 |
else:
|
739 |
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
740 |
|
741 |
+
lora_kwargs.pop('residual', None)
|
742 |
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
|
743 |
return out if not self.return_residual else (out, x)
|