mrsteyk committed
Commit cb65bb0
Parent: 0865e02

Upload 2 files


Who's getting the best head?

Files changed (2)
  1. model.py +548 -0
  2. train.py +317 -0
model.py ADDED
@@ -0,0 +1,548 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method


########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = int(os.environ["RWKV_T_MAX"])  # TAKES LOTS OF VRAM!
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load

wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 32) == 0
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return y
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return y.half()
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 32) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device=gy.device).contiguous()
        gu = torch.zeros((B, C), device=gy.device).contiguous()
        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if "32" in os.environ["RWKV_FLOAT_MODE"]:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w, u, k, v)
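
# [editor's note] A minimal pure-PyTorch sketch of the recurrence the CUDA kernel
# above computes, for reading without the kernel built. It takes w already negated
# (-exp(time_decay), as WKV.forward prepares it) and omits the running-max trick the
# kernel uses for numerical stability, so it is a reference, not a drop-in.
def wkv_reference(B, T, C, w, u, k, v):
    y = torch.empty((B, T, C), dtype=k.dtype, device=k.device)
    a = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # decayed running sum of exp(k_i) * v_i
    b = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # decayed running sum of exp(k_i)
    for t in range(T):
        kt, vt = k[:, t], v[:, t]
        y[:, t] = (a + torch.exp(u + kt) * vt) / (b + torch.exp(u + kt))  # current token gets the u bonus
        a = torch.exp(w) * (a + torch.exp(kt) * vt)
        b = torch.exp(w) * (b + torch.exp(kt))
    return y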


########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


class RWKV_TimeMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd
        self.my_testing = self.args.my_testing
        attn_sz = args.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)

        # if self.my_testing > 0:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def jit_func(self, x):

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch, Time, Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv
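
# [editor's note] What time_shift does: ZeroPad2d((0, 0, 1, -1)) on a (B, T, C) input
# pads one zero step at the front of the time axis and drops the last step, so output
# row t is input row t - 1 ("token shift"). A tiny self-check:
def _token_shift_demo():
    x = torch.arange(6.0).view(1, 3, 2)  # B=1, T=3, C=2
    shifted = nn.ZeroPad2d((0, 0, 1, -1))(x)
    assert torch.equal(shifted[0, 0], torch.zeros(2))  # first step is all zeros
    assert torch.equal(shifted[0, 1:], x[0, :-1])  # the rest is x shifted back by one
    return shifted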


class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.my_testing = self.args.my_testing
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * args.n_embd
        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)

        # if self.my_testing in [1]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
        # elif self.my_testing in [2]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

    # disabled experimental variant (my_testing):
    # k = self.key(xk)
    # # if self.my_testing in [0, 2]:
    # k = torch.square(torch.relu(k))
    # # elif self.my_testing == 1:
    # #     k = torch.square(torch.relu(k)) + k * self.aaa
    # kv = self.value(k)
    # r = self.receptance(xr)
    # # if self.my_testing == 0:
    # r = torch.sigmoid(r)
    # # elif self.my_testing == 2:
    # #     r = torch.sigmoid(r) + r * self.aaa
    # rkv = r * kv
    # return rkv
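
# [editor's note] Hypothetical single-token form of the channel mix for recurrent
# inference (not part of this commit). `x` and `last_x` are (B, 1, C); `last_x`
# carries the previous token's input, the role time_shift plays in the parallel form.
def channel_mix_step(cm, x, last_x):
    xk = x * cm.time_mix_k + last_x * (1 - cm.time_mix_k)
    xr = x * cm.time_mix_r + last_x * (1 - cm.time_mix_r)
    k = torch.square(torch.relu(cm.key(xk)))
    out = torch.sigmoid(cm.receptance(xr)) * cm.value(k)
    return out, x  # x becomes last_x for the next token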


########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, args.n_embd)))
                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, args.n_embd)))

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix(args, layer_id)

        self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
                x = x + pos_emb

        if self.layer_id == 0 and args.pre_ffn > 0:
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x
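
# [editor's note] Hypothetical smoke test for a single Block (assumes the module
# imported, i.e. the env vars and CUDA kernel build above succeeded). Using
# layer_id=0 with pre_ffn=1 avoids the CUDA WKV path, so the forward runs on CPU.
def _block_smoke_test():
    from types import SimpleNamespace
    args = SimpleNamespace(n_embd=64, n_layer=2, my_testing=0, my_pos_emb=0,
                           pre_ffn=1, tiny_att_dim=0, tiny_att_layer=-999)
    block = Block(args, layer_id=0)
    x = torch.randn(2, 8, 64)
    return block(x).shape  # torch.Size([2, 8, 64])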


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)
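
# [editor's note] What L2Wrap does: the loss value passes through unchanged, but
# backward adds a small extra gradient on each position's max logit, nudging logits
# toward 0. A hypothetical demo isolating that extra gradient:
def _l2wrap_demo():
    y = torch.tensor([[[1.0, 5.0, 2.0]]], requires_grad=True)  # (B=1, T=1, vocab=3)
    loss = y.sum() * 0.0  # dummy zero loss, so only L2Wrap's gradient remains
    L2Wrap.apply(loss, y).backward()
    return y.grad  # zeros except 5.0 * 1e-4 / (B * T) = 5e-4 at the max position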


class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def resize_emb(self, new_tokens: int):
        print(f"### RESIZING MODEL TO {new_tokens} TOKENS ###")

        new_embed = nn.Embedding(new_tokens, self.args.n_embd)
        new_embed.to(self.emb.weight.device, dtype=self.emb.weight.dtype)
        nn.init.zeros_(new_embed.weight)

        n = min(self.args.vocab_size, new_tokens)
        print("### Start emb copy", new_embed.weight.size(), self.emb.weight.size())
        new_embed.weight.data[:n, :] = self.emb.weight.data[:n, :]
        self.emb = new_embed
        print("### emb copy end")

        # Now resize the head in the same way
        new_head = nn.Linear(self.args.n_embd, new_tokens, bias=False)
        new_head.to(self.head.weight.device, dtype=self.head.weight.dtype)
        nn.init.orthogonal_(new_head.weight, gain=1 * 0.5)

        print("### Start head copy", new_head.weight.size(), self.head.weight.size())
        new_head.weight.data[:n, :] = self.head.weight.data[:n, :]
        self.head = new_head
        print("### RESIZE END")
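
    # [editor's note] Hypothetical usage of the resize_emb added in this commit
    # (train.py wires it to --vocab_size_delta): rows [0, old_vocab) of emb and head
    # are copied; new emb rows start at zero, the new head is orthogonal-init.
    #
    #   model.resize_emb(model.args.vocab_size + 128)
    #   model.args.vocab_size += 128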

    def configure_optimizers(self):
        args = self.args
        if args.layerwise_lr > 0:
            lr_1x = set()
            lr_2x = set()
            lr_3x = set()
            for n, p in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        lr_2x.add(n)
                    else:
                        lr_1x.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        lr_3x.add(n)
                    else:
                        lr_2x.add(n)
                elif "time_first" in n:
                    lr_3x.add(n)
                else:
                    lr_1x.add(n)
            lr_1x = sorted(list(lr_1x))
            lr_2x = sorted(list(lr_2x))
            lr_3x = sorted(list(lr_3x))
            # print('1x', lr_1x)
            # print('2x', lr_2x)
            # print('3x', lr_3x)
            param_dict = {n: p for n, p in self.named_parameters()}
            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 2e-3 / args.lr_init
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},  # test: 3e-3 / args.lr_init
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).float()  # one_hot is int64; cast to match c
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
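
    # [editor's note] Hypothetical greedy-decoding sketch built on the forward above
    # (assumes the CUDA kernel is available and the context fits in ctx_len):
    #
    #   @torch.no_grad()
    #   def greedy_generate(model, idx, steps):
    #       for _ in range(steps):
    #           logits = model(idx[:, -model.args.ctx_len:])
    #           next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    #           idx = torch.cat([idx, next_id], dim=1)
    #       return idx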

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask == 0:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        all_loss = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all_loss

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
                m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])
                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
train.py ADDED
@@ -0,0 +1,317 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    print("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0

    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default=-1, type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
    parser.add_argument("--vocab_size_delta", default=0, type=int)  # grow the vocab by this many extra tokens (see RWKV.resize_emb)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)

    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default=0, type=int)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    if args.random_seed >= 0:
        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = f"{args.vocab_size}+{args.vocab_size_delta} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime
        if args.ctx_len == 1024:
            args.magic_prime = 324331313
            args.epoch_count = 8043
        elif args.ctx_len == 2048:
            args.magic_prime = 162165671
            args.epoch_count = 4021
        elif args.ctx_len == 4096:
            args.magic_prime = 81082817
            args.epoch_count = 2010
        if args.my_pile_shift < 0:
            if args.ctx_len == 1024:
                args.my_pile_shift = 0
            elif args.ctx_len == 2048:
                args.my_pile_shift = 512
            elif args.ctx_len == 4096:
                args.my_pile_shift = 768

        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak

        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        if args.my_pile_stage == 2:
            assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p == "init":
                        p = -1
                    else:
                        p = int(p)
                    list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            if args.my_pile_stage == 2:
                args.warmup_steps = 10
            else:
                args.warmup_steps = 30
            args.epoch_begin = max_p + 1

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    if args.data_type == 'wds_img':
        from src.model_img import RWKV_IMG
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV
        model = RWKV(args)

    if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    print(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
    except Exception:
        print(f"Bad checkpoint {args.load_model}")
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            print(f"Trying {args.load_model}")
            load_dict = torch.load(args.load_model, map_location="cpu")

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    model.load_state_dict(load_dict)
    if args.vocab_size_delta > 0:
        # model.cuda()
        model.resize_emb(args.vocab_size + args.vocab_size_delta)
        args.vocab_size = args.vocab_size + args.vocab_size_delta

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )
    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)

    trainer.fit(model, data_loader)
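
    # [editor's note] Hypothetical invocation exercising this commit's new flag:
    # extend a loaded checkpoint's vocabulary by 128 rows (via RWKV.resize_emb)
    # before fine-tuning; other flags mirror the examples at the top of this file.
    #
    # python train.py --load_model "out/rwkv-init.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 --vocab_size_delta 128 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 50 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 1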