Upload 2 files
model.py
ADDED
@@ -0,0 +1,548 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc
import torch
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method

+
########################################################################################################
|
30 |
+
# CUDA Kernel
|
31 |
+
########################################################################################################
|
32 |
+
|
33 |
+
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
34 |
+
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
|
35 |
+
|
36 |
+
from torch.utils.cpp_extension import load
|
37 |
+
|
38 |
+
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])
|
39 |
+
|
40 |
+
|
41 |
+
class WKV(torch.autograd.Function):
|
42 |
+
@staticmethod
|
43 |
+
def forward(ctx, B, T, C, w, u, k, v):
|
44 |
+
ctx.B = B
|
45 |
+
ctx.T = T
|
46 |
+
ctx.C = C
|
47 |
+
assert T <= T_MAX
|
48 |
+
assert B * C % min(C, 32) == 0
|
49 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
50 |
+
w = -torch.exp(w.contiguous())
|
51 |
+
u = u.contiguous()
|
52 |
+
k = k.contiguous()
|
53 |
+
v = v.contiguous()
|
54 |
+
else:
|
55 |
+
w = -torch.exp(w.float().contiguous())
|
56 |
+
u = u.float().contiguous()
|
57 |
+
k = k.float().contiguous()
|
58 |
+
v = v.float().contiguous()
|
59 |
+
ctx.save_for_backward(w, u, k, v)
|
60 |
+
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
61 |
+
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
62 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
63 |
+
return y
|
64 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
65 |
+
return y.half()
|
66 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
67 |
+
return y.bfloat16()
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def backward(ctx, gy):
|
71 |
+
B = ctx.B
|
72 |
+
T = ctx.T
|
73 |
+
C = ctx.C
|
74 |
+
assert T <= T_MAX
|
75 |
+
assert B * C % min(C, 32) == 0
|
76 |
+
w, u, k, v = ctx.saved_tensors
|
77 |
+
gw = torch.zeros((B, C), device=gy.device).contiguous()
|
78 |
+
gu = torch.zeros((B, C), device=gy.device).contiguous()
|
79 |
+
gk = torch.zeros((B, T, C), device=gy.device).contiguous()
|
80 |
+
gv = torch.zeros((B, T, C), device=gy.device).contiguous()
|
81 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
82 |
+
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
|
83 |
+
else:
|
84 |
+
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
|
85 |
+
gw = torch.sum(gw, dim=0)
|
86 |
+
gu = torch.sum(gu, dim=0)
|
87 |
+
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
88 |
+
return (None, None, None, gw, gu, gk, gv)
|
89 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
90 |
+
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
91 |
+
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
92 |
+
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
93 |
+
|
94 |
+
|
95 |
+
def RUN_CUDA(B, T, C, w, u, k, v):
|
96 |
+
return WKV.apply(B, T, C, w, u, k, v)
|
97 |
+
|
98 |
+
|
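# --- Added note (not part of the original file): reference semantics of the WKV kernel ---
# A minimal, numerically naive PyTorch sketch of what wkv_cuda.forward is assumed to compute
# (the conventional RWKV-4 recurrence); the real kernel tracks a running maximum for numerical
# stability and is far faster. Here `w` is the already-negated decay (-exp(time_decay)) and
# `u` the per-channel bonus (time_first), exactly as passed to RUN_CUDA above.
def wkv_reference(B, T, C, w, u, k, v):
    y = torch.zeros(B, T, C, dtype=k.dtype, device=k.device)
    num = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # running weighted sum of v
    den = torch.zeros(B, C, dtype=k.dtype, device=k.device)  # running sum of weights
    for t in range(T):
        kt, vt = k[:, t, :], v[:, t, :]
        # current token gets the extra bonus u; past tokens live in the decayed state
        y[:, t, :] = (num + torch.exp(u + kt) * vt) / (den + torch.exp(u + kt))
        # decay the state by exp(w) per step, then absorb the current token without the bonus
        num = torch.exp(w) * num + torch.exp(kt) * vt
        den = torch.exp(w) * den + torch.exp(kt)
    return y
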
########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################


class RWKV_TimeMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.ctx_len = args.ctx_len
        self.n_embd = args.n_embd
        self.my_testing = self.args.my_testing
        attn_sz = args.n_embd

        with torch.no_grad():  # fancy init
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(args.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(args.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(args.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, args.n_embd, bias=False)

        # if self.my_testing > 0:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def jit_func(self, x):

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size()  # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv

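# --- Added note (not part of the original file): the "time_shift" used above is a one-step token
# shift along the time axis. nn.ZeroPad2d((0, 0, 1, -1)) pads one zero row at the front of dim -2
# and drops the last row, so for x of shape (B, T, C) it is equivalent to
#   torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
# i.e. every position is interpolated with the embedding of the previous token via time_mix_*.
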
class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.my_testing = self.args.my_testing
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * args.n_embd
        self.key = nn.Linear(args.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)

        # if self.my_testing in [1]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
        # elif self.my_testing in [2]:
        #     self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

        # k = self.key(xk)
        # # if self.my_testing in [0, 2]:
        # k = torch.square(torch.relu(k))
        # # elif self.my_testing == 1:
        # #     k = torch.square(torch.relu(k)) + k * self.aaa
        # kv = self.value(k)
        # r = self.receptance(xr)
        # # if self.my_testing == 0:
        # r = torch.sigmoid(r)
        # # elif self.my_testing == 2:
        # #     r = torch.sigmoid(r) + r * self.aaa
        # rkv = r * kv
        # return rkv

########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(torch.zeros((1, args.my_pos_emb, args.n_embd)))
                self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb, 1, args.n_embd)))

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix(args, layer_id)

        self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1, :]
                x = x + pos_emb

        if self.layer_id == 0 and args.pre_ffn > 0:
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

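# --- Added note (not part of the original file): L2Wrap passes the loss through unchanged in the
# forward pass and, in the backward pass, injects a small extra gradient (factor * max_logit) at
# the argmax position of each token's logits. That behaves like an auxiliary penalty of roughly
# 0.5 * factor * max_logit^2 per position, nudging the largest logit toward 0 so the head output
# stays well-scaled.
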
class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

    def resize_emb(self, new_tokens: int):
        print(f"### RESIZING MODEL TO {new_tokens} TOKENS ###")

        new_embed = nn.Embedding(new_tokens, self.args.n_embd)
        new_embed.to(self.emb.weight.device, dtype=self.emb.weight.dtype)
        nn.init.zeros_(new_embed.weight)

        n = min(self.args.vocab_size, new_tokens)
        print("### Start emb copy", new_embed.weight.size(), self.emb.weight.size())
        new_embed.weight.data[:n, :] = self.emb.weight.data[:n, :]
        self.emb = new_embed
        print("### emb copy end")

        # Now we resize head
        new_head = nn.Linear(self.args.n_embd, new_tokens, bias=False)
        new_head.to(self.head.weight.device, dtype=self.head.weight.dtype)
        nn.init.orthogonal_(new_head.weight, gain=1 * 0.5)

        print("### Start head copy", new_head.weight.size(), self.head.weight.size())
        new_head.weight.data[:n, :] = self.head.weight.data[:n, :]
        self.head = new_head
        print("### RESIZE END")

    def configure_optimizers(self):
        args = self.args
        if args.layerwise_lr > 0:
            lr_1x = set()
            lr_2x = set()
            lr_3x = set()
            for n, p in self.named_parameters():
                if "time_mix" in n:
                    if args.my_pile_stage == 2:
                        lr_2x.add(n)
                    else:
                        lr_1x.add(n)
                elif "time_decay" in n:
                    if args.my_pile_stage == 2:
                        lr_3x.add(n)
                    else:
                        lr_2x.add(n)
                elif "time_first" in n:
                    lr_3x.add(n)
                else:
                    lr_1x.add(n)
            lr_1x = sorted(list(lr_1x))
            lr_2x = sorted(list(lr_2x))
            lr_3x = sorted(list(lr_3x))
            # print('1x', lr_1x)
            # print('2x', lr_2x)
            # print('3x', lr_3x)
            param_dict = {n: p for n, p in self.named_parameters()}
            if args.my_pile_stage == 2:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
                    {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
                    {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
                ]
        else:
            optim_groups = [
                {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
            ]

        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
        return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

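    # --- Added note (not part of the original file): the head_qk branch above is the "headQK trick"
    # from the train.py options. c is a causal T x T matrix of scaled dot products between the
    # head_q and head_k projections; multiplying it by one_hot(idx) turns each row into vocab-sized
    # logits that boost tokens which already appeared earlier in the context (a learned copy
    # mechanism), and these are added to the normal language-model head output.
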
    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask == 0:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        all = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = all

    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape

            gain = 1.0
            scale = 1.0
            if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n:
                m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])
                    for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q."]:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()

            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
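
Note (added, not part of the upload): the sketch below shows one way model.py could be exercised on its own, assuming a CUDA toolchain is available to build the kernel, that the script is run from the repo root with the sources in cuda/, and that this file is importable as src.model as train.py expects (the upload shows it at the repo root, so the import path may need adjusting). The args namespace is a hypothetical stand-in for the argparse object that train.py builds; the real entry point is train.py below.

    import os, types, torch
    # environment variables read by model.py at import / run time
    os.environ["RWKV_JIT_ON"] = "0"
    os.environ["RWKV_T_MAX"] = "128"
    os.environ["RWKV_FLOAT_MODE"] = "fp32"
    from src.model import RWKV  # hypothetical import path, matching train.py

    args = types.SimpleNamespace(
        vocab_size=50277, n_layer=6, n_embd=512, ctx_len=128,
        my_pos_emb=0, pre_ffn=0, head_qk=0, tiny_att_dim=0, tiny_att_layer=-999,
        my_testing=0, grad_cp=0, lr_init=6e-4, betas=(0.9, 0.99), adam_eps=1e-8,
        layerwise_lr=1, my_pile_stage=0,
    )
    model = RWKV(args).cuda()  # the WKV kernel requires CUDA
    idx = torch.randint(0, args.vocab_size, (1, 128), device="cuda")
    logits = model(idx)  # shape (1, 128, vocab_size)
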
train.py
ADDED
@@ -0,0 +1,317 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

if __name__ == "__main__":
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    print("########## work in progress ##########")

    ########################################################################################################
    #
    # example: train a simple L12-D768 RWKV on dummy data
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "" --data_type "dummy" --vocab_size 0 \
    # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
    # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
    # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: train a simple L6-D512 RWKV from scratch on enwik8
    #
    # python train.py --load_model "" --wandb "" --proj_dir "out" \
    # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
    # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
    # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0

    # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
    # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0

    # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
    #
    # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
    # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
    # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
    # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
    # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
    # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1

    parser = ArgumentParser()

    parser.add_argument("--load_model", default="", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)
    parser.add_argument("--vocab_size_delta", default=0, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=1024, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)

    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_testing", default=0, type=int)

    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    ########################################################################################################

    import os, warnings, math, datetime, sys, time
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    import deepspeed
    import pytorch_lightning as pl
    from pytorch_lightning import seed_everything
    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    if args.random_seed >= 0:
        print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
        seed_everything(args.random_seed)

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = f"{args.vocab_size}+{args.vocab_size_delta} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
    if not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime
        if args.ctx_len == 1024:
            args.magic_prime = 324331313
            args.epoch_count = 8043
        elif args.ctx_len == 2048:
            args.magic_prime = 162165671
            args.epoch_count = 4021
        elif args.ctx_len == 4096:
            args.magic_prime = 81082817
            args.epoch_count = 2010
        if args.my_pile_shift < 0:
            if args.ctx_len == 1024:
                args.my_pile_shift = 0
            elif args.ctx_len == 2048:
                args.my_pile_shift = 512
            elif args.ctx_len == 4096:
                args.my_pile_shift = 768

        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak

        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        if args.my_pile_stage == 2:
            assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p == "init":
                        p = -1
                    else:
                        p = int(p)
                    list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                if args.my_pile_stage == 2:
                    args.warmup_steps = 10
                else:
                    args.warmup_steps = 30
            args.epoch_begin = max_p + 1

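    # --- Added note (not part of the original file): in this pile mode every "epoch" is pinned to
    # 40320 samples, so real_bsz must divide 40320. For example, 8 GPUs x micro_bsz 8 gives
    # real_bsz = 64 and epoch_steps = 40320 // 64 = 630 steps per "epoch".
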
    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len
    rank_zero_info(
        f"""
############################################################################
#
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
#
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
#
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
#
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
#
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
#
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer
#
############################################################################
"""
    )
    rank_zero_info(str(vars(args)) + "\n")

    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    if args.lr_final == 0 or args.lr_init == 0:
        rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")

    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    if args.precision == "fp32":
        rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
    if args.precision == "fp16":
        rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")

    os.environ["RWKV_JIT_ON"] = "1"
    if "deepspeed_stage_3" in args.strategy:
        os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    if "32" in args.precision:
        args.precision = 32
    elif args.precision == "fp16":
        args.precision = 16
    else:
        args.precision = "bf16"

    ########################################################################################################

    from src.trainer import train_callback, generate_init_weight
    from src.dataset import MyDataset

    train_data = MyDataset(args)
    args.vocab_size = train_data.vocab_size

    if args.data_type == 'wds_img':
        from src.model_img import RWKV_IMG
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV
        model = RWKV(args)

    if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
        init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
        generate_init_weight(model, init_weight_name)  # save initial weights
        args.load_model = init_weight_name

    print(f"########## Loading {args.load_model}... ##########")
    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
    except:
        print(f"Bad checkpoint {args.load_model}")
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            print(f"Trying {args.load_model}")
            load_dict = torch.load(args.load_model, map_location="cpu")

    if args.load_partial == 1:
        load_keys = load_dict.keys()
        for k in model.state_dict():
            if k not in load_keys:
                load_dict[k] = model.state_dict()[k]
    model.load_state_dict(load_dict)
    if args.vocab_size_delta > 0:
        # model.cuda()
        model.resize_emb(args.vocab_size + args.vocab_size_delta)
        args.vocab_size = args.vocab_size + args.vocab_size_delta

    trainer = Trainer.from_argparse_args(
        args,
        callbacks=[train_callback(args)],
    )
    if "deepspeed" in args.strategy:
        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000

    # must set shuffle=False, persistent_workers=False (because worker is in another thread)
    data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)

    trainer.fit(model, data_loader)