Jonathan Tow
commited on
Commit
•
997d959
1
Parent(s):
1465d75
refactor: clean nn package access
Browse files
modeling_stablelm_alpha.py
CHANGED
@@ -112,7 +112,7 @@ class MLP(nn.Module):
|
|
112 |
ff_dim = int(8 * hidden_size / 3)
|
113 |
intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
|
114 |
|
115 |
-
self.gate_proj =
|
116 |
self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
117 |
self.act = nn.SiLU()
|
118 |
|
@@ -121,7 +121,7 @@ class MLP(nn.Module):
|
|
121 |
return self.out_proj(ff * self.act(ff_gate))
|
122 |
|
123 |
|
124 |
-
class RotaryEmbedding(
|
125 |
def __init__(
|
126 |
self,
|
127 |
dim: int,
|
|
|
112 |
ff_dim = int(8 * hidden_size / 3)
|
113 |
intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
|
114 |
|
115 |
+
self.gate_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
|
116 |
self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
117 |
self.act = nn.SiLU()
|
118 |
|
|
|
121 |
return self.out_proj(ff * self.act(ff_gate))
|
122 |
|
123 |
|
124 |
+
class RotaryEmbedding(nn.Module):
|
125 |
def __init__(
|
126 |
self,
|
127 |
dim: int,
|