njeffrie commited on
Commit
40d1a51
1 Parent(s): 1e12629

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MoonshineModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_moonshine.MoonshineConfig",
7
+ "AutoModelForCausalLM": "modeling_moonshine.MoonshineModel"
8
+ },
9
+ "dec_depth": 6,
10
+ "dec_ff_swiglu": true,
11
+ "dec_voc_size": 32768,
12
+ "dim": 288,
13
+ "enc_depth": 6,
14
+ "enc_ff_swiglu": false,
15
+ "inner_dim": 288,
16
+ "model_type": "moonshine",
17
+ "n_head": 8,
18
+ "torch_dtype": "float32",
19
+ "transformers_version": "4.46.1"
20
+ }
configuration_moonshine.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ class MoonshineConfig(PretrainedConfig):
6
+ model_type = "moonshine"
7
+
8
+ def __init__(
9
+ self,
10
+ dim: int = 288,
11
+ inner_dim: int = None,
12
+ enc_depth: int = 8,
13
+ dec_depth: int = 8,
14
+ n_head: int = 8,
15
+ dec_voc_size: int = 32768,
16
+ enc_ff_swiglu: bool = False,
17
+ dec_ff_swiglu: bool = True,
18
+ **kwargs
19
+ ):
20
+ if inner_dim is None:
21
+ inner_dim = dim
22
+ if inner_dim % n_head != 0:
23
+ raise ValueError("`inner dim` must be divisible by `n_head`")
24
+ self.dim = dim
25
+ self.inner_dim = inner_dim
26
+ self.enc_depth = enc_depth
27
+ self.dec_depth = dec_depth
28
+ self.n_head = n_head
29
+ self.dec_voc_size = dec_voc_size
30
+ self.enc_ff_swiglu = enc_ff_swiglu
31
+ self.dec_ff_swiglu = dec_ff_swiglu
32
+ super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f46496c082ab898f5414e31bae398953aa205fb5fc614eb8be7f0d8d8ddd0aa
3
+ size 186049168
modeling_moonshine.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ from einops.layers.torch import Rearrange
3
+ from torch import nn
4
+ from transformers import PreTrainedModel
5
+
6
+ import math
7
+ import torch
8
+
9
+ from .configuration_moonshine import MoonshineConfig
10
+
11
+
12
+ class RotaryEmbedding(nn.Module):
13
+ def __init__(self, dim, base=10000):
14
+ super().__init__()
15
+
16
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
17
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
18
+
19
+ def forward(self, t):
20
+ freqs = torch.einsum("i , j -> i j", t.type_as(self.inv_freq), self.inv_freq)
21
+ freqs = torch.stack((freqs, freqs), dim=-1)
22
+ return rearrange(freqs, "... d r -> ... (d r)")
23
+
24
+
25
+ def rotate_half(x):
26
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
27
+ x1, x2 = x.unbind(dim=-1)
28
+ x = torch.stack((-x2, x1), dim=-1)
29
+ return rearrange(x, "... d r -> ... (d r)")
30
+
31
+
32
+ def apply_rotary_pos_emb(t, freqs):
33
+ rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
34
+
35
+ freqs = freqs[-seq_len:, :]
36
+
37
+ # partial rotary embeddings, Wang et al. GPT-J
38
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
39
+ t = t * freqs.cos() + rotate_half(t) * freqs.sin()
40
+ out = torch.cat((t, t_unrotated), dim=-1)
41
+
42
+ return out.type(orig_dtype)
43
+
44
+
45
+ class MultiHeadAttention(nn.Module):
46
+ def __init__(self, dim, inner_dim, n_head):
47
+ super().__init__()
48
+ self.n_head = n_head
49
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
50
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
51
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
52
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
53
+ self.softmax = nn.Softmax(dim=-1)
54
+
55
+ # Scaled dot product attention
56
+ def sdp_attention(self, q, k_t, v, mask=None):
57
+ d_tensor = v.shape[3]
58
+
59
+ op = (q @ k_t) / math.sqrt(d_tensor)
60
+ if mask is not None:
61
+ op = op.masked_fill(mask, -torch.finfo(op.dtype).max)
62
+ score = self.softmax(op)
63
+ out = score @ v
64
+
65
+ # concat and pass to linear layer
66
+ out = rearrange(out, "b h n d -> b n (h d)")
67
+ return self.to_out(out)
68
+
69
+ def forward(self, q, k, v, rot_pos_emb=None, mask=None):
70
+ # dot product with weight matrices
71
+ q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
72
+
73
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
74
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
75
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
76
+
77
+ # apply RoPE
78
+ if rot_pos_emb is not None:
79
+ q = apply_rotary_pos_emb(q, rot_pos_emb)
80
+ k = apply_rotary_pos_emb(k, rot_pos_emb)
81
+
82
+ k_t = k.transpose(2, 3)
83
+
84
+ return self.sdp_attention(q, k_t, v, mask), k_t, v
85
+
86
+
87
+ class MultiHeadCausalSelfAttentionWithKVCache(MultiHeadAttention):
88
+ def __init__(self, dim, inner_dim, n_head):
89
+ super().__init__(dim, inner_dim, n_head)
90
+
91
+ def forward(self, q, k, v, k_cache, v_cache, rot_pos_emb, mask):
92
+ # dot product with weight matrices
93
+ q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
94
+
95
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
96
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.n_head)
97
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.n_head)
98
+
99
+ # apply RoPE
100
+ q = apply_rotary_pos_emb(q, rot_pos_emb)
101
+ k = apply_rotary_pos_emb(k, rot_pos_emb)
102
+
103
+ k_t = k.transpose(2, 3)
104
+
105
+ # Append new rows to K and V caches.
106
+ k_t = torch.concat((k_cache, k_t), dim=3)
107
+ v = torch.concat((v_cache, v), dim=2)
108
+
109
+ return super().sdp_attention(q, k_t, v, mask=mask), k_t, v
110
+
111
+
112
+ class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
113
+ def __init__(self, dim, inner_dim, n_head):
114
+ super().__init__(dim, inner_dim, n_head)
115
+
116
+ def forward(self, q, k_cache, v_cache):
117
+ q = self.to_q(q)
118
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
119
+
120
+ return super().sdp_attention(q, k_cache, v_cache)
121
+
122
+
123
+ class FFLinearGelu(nn.Module):
124
+ def __init__(self, dim, ff_mult=4):
125
+ super().__init__()
126
+
127
+ self.ff = nn.Sequential(
128
+ nn.Linear(dim, dim * ff_mult, bias=True),
129
+ nn.GELU(),
130
+ nn.Linear(dim * ff_mult, dim, bias=True),
131
+ )
132
+
133
+ def forward(self, x):
134
+ return self.ff(x)
135
+
136
+
137
+ class FFSwiGLU(nn.Module):
138
+ def __init__(self, dim, ff_mult=4):
139
+ super().__init__()
140
+
141
+ self.ff_proj = nn.Linear(dim, dim * ff_mult, bias=True)
142
+ self.ff_noact = nn.Linear(dim, dim * ff_mult, bias=True)
143
+ self.ff_act = nn.SiLU()
144
+ self.ff_out = nn.Linear(dim * ff_mult, dim, bias=True)
145
+
146
+ def forward(self, x):
147
+ gate = self.ff_act(self.ff_proj(x))
148
+ x_noact = self.ff_noact(x)
149
+ x = x_noact * gate
150
+ return self.ff_out(x)
151
+
152
+
153
+ class EncoderLayer(nn.Module):
154
+ def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
155
+ super().__init__()
156
+
157
+ self.norm1 = nn.LayerNorm(dim, bias=False)
158
+
159
+ self.attention = MultiHeadAttention(dim, inner_dim=inner_dim, n_head=n_head)
160
+
161
+ self.norm2 = nn.LayerNorm(dim, bias=False)
162
+
163
+ self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
164
+
165
+ def forward(self, x, rot_pos_emb):
166
+ _x = x
167
+ x = self.norm1(x)
168
+ x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb)
169
+ x = x + _x
170
+
171
+ _x = x
172
+ x = self.norm2(x)
173
+ x = self.ff(x)
174
+
175
+ x = x + _x
176
+ return x
177
+
178
+
179
+ class Encoder(nn.Module):
180
+ def __init__(self, dim, inner_dim, n_head, n_layers, ff_swiglu):
181
+ super().__init__()
182
+ rot_embed_dim = max(inner_dim / n_head / 2, 32)
183
+ self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
184
+
185
+ self.layers = nn.ModuleList(
186
+ [EncoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
187
+ )
188
+ self.post_norm = nn.LayerNorm(dim, bias=False)
189
+
190
+ def forward(self, x):
191
+ pos = torch.arange(x.shape[1], device=x.device)
192
+ rot_pos_emb = self.rot_pos_emb(pos)
193
+
194
+ for layer in self.layers:
195
+ x = layer(x, rot_pos_emb=rot_pos_emb)
196
+ return self.post_norm(x)
197
+
198
+
199
+ class DecoderLayer(nn.Module):
200
+ def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
201
+ super().__init__()
202
+
203
+ self.norm1 = nn.LayerNorm(dim, bias=False)
204
+
205
+ self.self_attention = MultiHeadCausalSelfAttentionWithKVCache(
206
+ dim, inner_dim=inner_dim, n_head=n_head
207
+ )
208
+
209
+ self.norm2 = nn.LayerNorm(dim, bias=False)
210
+ self.cross_attention = MultiHeadCrossAttentionWithKVCache(
211
+ dim, inner_dim=inner_dim, n_head=n_head
212
+ )
213
+
214
+ self.norm3 = nn.LayerNorm(dim, bias=False)
215
+ self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
216
+
217
+ def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb):
218
+ dim = x.size()[1]
219
+ causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
220
+ _x = x
221
+ x = self.norm1(x)
222
+ x, new_k_cache, new_v_cache = self.self_attention(
223
+ q=x,
224
+ k=x,
225
+ v=x,
226
+ k_cache=k_cache,
227
+ v_cache=v_cache,
228
+ rot_pos_emb=rot_pos_emb,
229
+ mask=causal_mask,
230
+ )
231
+ x = x + _x
232
+
233
+ _x = x
234
+ x = self.norm2(x)
235
+ x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache)
236
+ x = x + _x
237
+
238
+ _x = x
239
+ x = self.norm3(x)
240
+ x = self.ff(x)
241
+ x = x + _x
242
+
243
+ return x, new_k_cache, new_v_cache
244
+
245
+
246
+ class Decoder(nn.Module):
247
+ def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
248
+ super().__init__()
249
+
250
+ self.n_head = n_head
251
+ self.d_head = inner_dim // n_head
252
+
253
+ rot_embed_dim = max(inner_dim / n_head / 2, 32)
254
+ self.rot_pos_emb = RotaryEmbedding(rot_embed_dim)
255
+
256
+ self.layers = nn.ModuleList(
257
+ [DecoderLayer(dim, inner_dim, n_head, ff_swiglu) for _ in range(n_layers)]
258
+ )
259
+ self.final_norm = nn.LayerNorm(dim, bias=False)
260
+ self.token_embedding = nn.Embedding(dec_voc_size, dim)
261
+
262
+ def forward(self, x, *args):
263
+ pos = torch.arange(x.shape[1], device=x.device)
264
+ rot_pos_emb = self.rot_pos_emb(pos)
265
+ x = self.token_embedding(x)
266
+
267
+ k_cache_new = []
268
+ v_cache_new = []
269
+
270
+ n_layer = len(self.layers)
271
+ k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
272
+ args[i : i + n_layer] for i in range(0, 4 * n_layer, n_layer)
273
+ ]
274
+ for idx, layer in enumerate(self.layers):
275
+ x, new_k_line, new_v_line = layer(
276
+ x[:, -1:],
277
+ k_cache=k_cache[idx],
278
+ v_cache=v_cache[idx],
279
+ x_attn_k_cache=x_attn_k_cache[idx],
280
+ x_attn_v_cache=x_attn_v_cache[idx],
281
+ rot_pos_emb=rot_pos_emb,
282
+ )
283
+ k_cache_new.append(new_k_line)
284
+ v_cache_new.append(new_v_line)
285
+
286
+ x = self.final_norm(x)
287
+
288
+ return x @ self.token_embedding.weight.t(), *k_cache_new, *v_cache_new
289
+
290
+
291
+ class InitialDecoderLayer(nn.Module):
292
+ def __init__(self, dim, inner_dim, n_head, ff_swiglu, ff_mult=4):
293
+ super().__init__()
294
+
295
+ self.norm1 = nn.LayerNorm(dim, bias=False)
296
+
297
+ self.self_attention = MultiHeadAttention(
298
+ dim, inner_dim=inner_dim, n_head=n_head
299
+ )
300
+
301
+ self.norm2 = nn.LayerNorm(dim, bias=False)
302
+ self.cross_attention = MultiHeadAttention(
303
+ dim, inner_dim=inner_dim, n_head=n_head
304
+ )
305
+
306
+ self.norm3 = nn.LayerNorm(dim, bias=False)
307
+ self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
308
+
309
+ def forward(self, x, context, rot_pos_emb):
310
+ dim = x.size()[1]
311
+ causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
312
+ _x = x
313
+ x = self.norm1(x)
314
+ x, new_k_cache, new_v_cache = self.self_attention(
315
+ q=x,
316
+ k=x,
317
+ v=x,
318
+ rot_pos_emb=rot_pos_emb,
319
+ mask=causal_mask,
320
+ )
321
+ x = x + _x
322
+
323
+ _x = x
324
+ x = self.norm2(x)
325
+ x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
326
+ q=x, k=context, v=context
327
+ )
328
+ x = x + _x
329
+
330
+ _x = x
331
+ x = self.norm3(x)
332
+ x = self.ff(x)
333
+ x = x + _x
334
+
335
+ return x, new_k_cache, new_v_cache, x_attn_k_cache, x_attn_v_cache
336
+
337
+
338
+ class DecoderInitial(Decoder):
339
+ def __init__(self, dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu):
340
+ super().__init__(dim, inner_dim, n_head, n_layers, dec_voc_size, ff_swiglu)
341
+ self.layers = nn.ModuleList(
342
+ [
343
+ InitialDecoderLayer(dim, inner_dim, n_head, ff_swiglu)
344
+ for _ in range(n_layers)
345
+ ]
346
+ )
347
+
348
+ def forward(self, x, enc_src):
349
+ pos = torch.arange(x.shape[1], device=x.device)
350
+ rot_pos_emb = self.rot_pos_emb(pos)
351
+ x = self.token_embedding(x)
352
+
353
+ # Shape [n_layers, batch_size, n_head, seq_len, inner_dim]. Cache K transposed.
354
+ n_layer = len(self.layers)
355
+ k_cache = []
356
+ v_cache = []
357
+ x_attn_k_cache = []
358
+ x_attn_v_cache = []
359
+
360
+ for idx, layer in enumerate(self.layers):
361
+ x, new_k_line, new_v_line, new_x_attn_k_line, new_x_attn_v_line = layer(
362
+ x,
363
+ enc_src,
364
+ rot_pos_emb,
365
+ )
366
+
367
+ k_cache.append(new_k_line)
368
+ v_cache.append(new_v_line)
369
+ x_attn_k_cache.append(new_x_attn_k_line)
370
+ x_attn_v_cache.append(new_x_attn_v_line)
371
+
372
+ x = self.final_norm(x)
373
+
374
+ return (
375
+ x @ self.token_embedding.weight.t(),
376
+ *k_cache,
377
+ *v_cache,
378
+ *x_attn_k_cache,
379
+ *x_attn_v_cache,
380
+ )
381
+
382
+
383
+ class AudioPreprocessor(nn.Module):
384
+ def __init__(self, dim):
385
+ super().__init__()
386
+ self.audio_preprocess = nn.Sequential(
387
+ nn.Conv1d(1, dim, 127, 64, bias=False),
388
+ nn.Tanh(),
389
+ nn.GroupNorm(1, dim),
390
+ nn.Conv1d(dim, 2 * dim, 7, 3),
391
+ nn.GELU(),
392
+ nn.Conv1d(2 * dim, dim, 3, 2),
393
+ nn.GELU(),
394
+ Rearrange("... c s -> ... s c"),
395
+ )
396
+
397
+ def forward(self, src):
398
+ assert (
399
+ src.shape[-1] >= 1023
400
+ ), f"src shape[-1] {src.shape[-1]} should be at least 1023"
401
+ src = src.unsqueeze(-2)
402
+ return self.audio_preprocess(src)
403
+
404
+
405
+ class MoonshineModelTorch(nn.Module):
406
+ def __init__(
407
+ self,
408
+ dim,
409
+ inner_dim,
410
+ enc_depth,
411
+ dec_depth,
412
+ n_head=8,
413
+ dec_voc_size=32768,
414
+ enc_ff_swiglu=False,
415
+ dec_ff_swiglu=False,
416
+ ):
417
+ super().__init__()
418
+ self.preprocessor = AudioPreprocessor(dim)
419
+ self.encoder = Encoder(
420
+ dim, inner_dim, n_head, enc_depth, ff_swiglu=enc_ff_swiglu
421
+ )
422
+ self.decoder_initial = DecoderInitial(
423
+ dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
424
+ )
425
+ self.decoder = Decoder(
426
+ dim, inner_dim, n_head, dec_depth, dec_voc_size, ff_swiglu=dec_ff_swiglu
427
+ )
428
+ self.dec_depth = dec_depth
429
+ self.n_head = n_head
430
+ self.d_head = inner_dim // n_head
431
+
432
+ def generate(self, src):
433
+ preprocessed = self.preprocessor(src)
434
+ enc = self.encoder(preprocessed)
435
+ sot_token = 1
436
+ eot_token = 2
437
+
438
+ seq = torch.as_tensor([[sot_token]]).to(src.device)
439
+
440
+ vals = self.decoder_initial(x=seq, enc_src=enc)
441
+ logits = vals[0]
442
+ k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
443
+ vals[i : i + self.dec_depth]
444
+ for i in range(1, 1 + self.dec_depth * 4, self.dec_depth)
445
+ ]
446
+
447
+ sample = logits[:, -1].argmax(dim=-1, keepdim=True)
448
+ seq = torch.cat((seq, sample), dim=-1)
449
+
450
+ seq_len = int(src.shape[-1] * 6 / 16000)
451
+ while sample != eot_token and len(seq.flatten()) <= seq_len:
452
+ vals = self.decoder(
453
+ seq,
454
+ *k_cache,
455
+ *v_cache,
456
+ *x_attn_k_cache,
457
+ *x_attn_v_cache,
458
+ )
459
+ logits = vals[0]
460
+ k_cache = vals[1 : self.dec_depth + 1]
461
+ v_cache = vals[self.dec_depth + 1 :]
462
+ logits = logits[:, -1] # get last token
463
+ sample = logits.argmax(dim=-1, keepdim=True)
464
+ seq = torch.cat((seq, sample), dim=-1)
465
+
466
+ return seq
467
+
468
+
469
+ class MoonshineModel(PreTrainedModel):
470
+ config_class = MoonshineConfig
471
+
472
+ def __init__(self, config):
473
+ super().__init__(config)
474
+ self.model = MoonshineModelTorch(
475
+ dim = config.dim,
476
+ inner_dim = config.inner_dim,
477
+ enc_depth = config.enc_depth,
478
+ dec_depth = config.dec_depth,
479
+ n_head = config.n_head,
480
+ dec_voc_size = config.dec_voc_size,
481
+ enc_ff_swiglu = config.enc_ff_swiglu,
482
+ dec_ff_swiglu = config.dec_ff_swiglu,
483
+ )
484
+
485
+ def forward(self, tensor):
486
+ return self.model.generate(tensor)