kaizen9 commited on
Commit
0ab4ea5
1 Parent(s): 7cc3e06

Update attention.py

Browse files
Files changed (1) hide show
  1. attention.py +2 -2
attention.py CHANGED
@@ -324,7 +324,7 @@ class MultiheadAttention(GroupedQueryAttention):
324
  additive bias.
325
  """
326
 
327
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='te', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
328
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
329
 
330
  class MultiQueryAttention(GroupedQueryAttention):
@@ -333,7 +333,7 @@ class MultiQueryAttention(GroupedQueryAttention):
333
  additive bias.
334
  """
335
 
336
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='te', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
337
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
338
 
339
  def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
 
324
  additive bias.
325
  """
326
 
327
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
328
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=n_heads, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
329
 
330
  class MultiQueryAttention(GroupedQueryAttention):
 
333
  additive bias.
334
  """
335
 
336
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, qk_gn: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, bias: bool=True, sliding_window_size: int=-1):
337
  super().__init__(d_model=d_model, n_heads=n_heads, kv_n_heads=1, attn_impl=attn_impl, clip_qkv=clip_qkv, qk_ln=qk_ln, qk_gn=qk_gn, softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, device=device, bias=bias, sliding_window_size=sliding_window_size)
338
 
339
  def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: