Cyleux commited on
Commit
8cea444
1 Parent(s): 688e796

Upload 18 files

Browse files
Files changed (18) hide show
  1. 20B_tokenizer.json +0 -0
  2. chat_kivy.py +130 -0
  3. cuda/wkv_cuda.cu +133 -0
  4. cuda/wkv_cuda_bf16.cu +132 -0
  5. cuda/wkv_op.cpp +21 -0
  6. cuda/wkv_op_bf16.cpp +25 -0
  7. run.py +223 -0
  8. src/__init__.py +0 -0
  9. src/binidx.py +269 -0
  10. src/dataset.py +245 -0
  11. src/model.py +610 -0
  12. src/model_img.py +446 -0
  13. src/model_run.py +233 -0
  14. src/trainer.py +192 -0
  15. src/utils.py +130 -0
  16. train.py +350 -0
  17. verify.py +104 -0
  18. zrwkv-37fifth.pth +3 -0
20B_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
chat_kivy.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print('Loading...')
2
+ from src.model_run import RWKV_RNN
3
+ import numpy as np
4
+ import os, copy, types, gc, sys
5
+ import torch
6
+ from src.utils import TOKENIZER
7
+
8
+ torch.backends.cudnn.benchmark = False
9
+ torch.backends.cudnn.allow_tf32 = False
10
+ torch.backends.cuda.matmul.allow_tf32 = False
11
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
12
+
13
+ WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"]
14
+ UNKNOWN_CHAR = None
15
+ tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
16
+
17
+ args = types.SimpleNamespace()
18
+ args.RUN_DEVICE = "cpu"
19
+ args.FLOAT_MODE = "fp32"
20
+ args.vocab_size = 50277
21
+ args.MODEL_NAME = 'zrwkv-37fifth'
22
+ args.n_layer = 12
23
+ args.n_embd = 768
24
+ args.ctx_len = 1024
25
+
26
+ user = "User"
27
+ bot = "Daniel"
28
+ interface = ":"
29
+
30
+ os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
31
+ MODEL_NAME = args.MODEL_NAME
32
+
33
+ print(f'loading... {MODEL_NAME}')
34
+ model = RWKV_RNN(args)
35
+
36
+ model_tokens = []
37
+ current_state = None
38
+
39
+ def run_rnn(tokens, newline_adj = 0):
40
+ global model_tokens, current_state
41
+ for i in range(len(tokens)):
42
+ model_tokens += [int(tokens[i])]
43
+ if i == len(tokens) - 1:
44
+ out, current_state = model.forward(model_tokens, current_state)
45
+ else:
46
+ current_state = model.forward(model_tokens, current_state, preprocess_only = True)
47
+
48
+ out[0] = -999999999
49
+ out[187] += newline_adj
50
+ return out
51
+
52
+ all_state = {}
53
+ def save_all_stat(name, last_out):
54
+ all_state[name] = {}
55
+ all_state[name]['out'] = last_out
56
+ all_state[name]['rnn'] = copy.deepcopy(current_state)
57
+ all_state[name]['token'] = copy.deepcopy(model_tokens)
58
+
59
+ def load_all_stat(name):
60
+ global model_tokens, current_state
61
+ current_state = copy.deepcopy(all_state[name]['rnn'])
62
+ model_tokens = copy.deepcopy(all_state[name]['token'])
63
+ return all_state[name]['out']
64
+
65
+ print(f'\nRun prompt...')
66
+
67
+ out = ""
68
+ gc.collect()
69
+
70
+ save_all_stat('chat_init', out)
71
+ save_all_stat('chat', out) # ensure that 'chat' key is added to all_state
72
+
73
+ print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')
74
+
75
+
76
+ def reply_msg(msg):
77
+ print(f'{bot}{interface} {msg}\n')
78
+
79
+ def on_message(message):
80
+ global model_tokens, current_state
81
+
82
+ msg = message.replace('\\n','\n').strip()
83
+ if len(msg) > 10000:
84
+ reply_msg('your message is too long (max 1000 tokens)')
85
+ return
86
+
87
+ out = load_all_stat('chat')
88
+ new = f"{user}{interface} {msg}\n{bot}{interface}"
89
+ out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999)
90
+ save_all_stat('chat_pre', out)
91
+
92
+ begin = len(model_tokens)
93
+ out_last = begin
94
+ print(f'{bot}{interface}', end='', flush=True)
95
+ for i in range(8000):
96
+ token = tokenizer.sample_logits(
97
+ out,
98
+ model_tokens,
99
+ args.ctx_len,
100
+ temperature=1.0,
101
+ top_p_usual=0.85,
102
+ top_p_newline=0.85,
103
+ )
104
+ out = run_rnn([token], newline_adj=1)
105
+
106
+ xxx = tokenizer.tokenizer.decode(model_tokens[out_last:])
107
+ if '\ufffd' not in xxx and 'user' not in str(xxx).lower() and '\n' not in xxx and str(xxx) != ':' and str(xxx) != '\n\n' and len(str(xxx)) > 0:
108
+ print(xxx, end='', flush=True)
109
+ out_last = begin + i + 1
110
+ else:
111
+ print('\n', end='', flush=True)
112
+ out_last = begin + i + 1
113
+
114
+ send_msg = tokenizer.tokenizer.decode(model_tokens[begin:])
115
+ if '\ufffd' in send_msg or send_msg.endswith(f'{user}{interface}') or send_msg.endswith(f'{bot}{interface}') or '\n' in send_msg:
116
+ send_msg = send_msg.strip()
117
+ send_msg = send_msg.replace(f'{user}{interface}', '')
118
+ send_msg = send_msg.replace(f'{bot}{interface}', '')
119
+ send_msg = send_msg.replace('\n', '')
120
+ break
121
+ save_all_stat('chat', out)
122
+
123
+ print('Start chatting with Daniel!')
124
+
125
+ while True:
126
+ msg = input(f'{user}{interface} ')
127
+ if len(msg.strip()) > 0:
128
+ on_message(msg)
129
+ else:
130
+ print('Error: please say something')
cuda/wkv_cuda.cu ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+
4
+ #define MIN_VALUE (-1e38)
5
+
6
+ template <typename F>
7
+ __global__ void kernel_forward(const int B, const int T, const int C,
8
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
9
+ F *__restrict__ const _y) {
10
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
11
+ const int _b = idx / C;
12
+ const int _c = idx % C;
13
+ const int _offset = _b * T * C + _c;
14
+
15
+ F u = _u[_c];
16
+ F w = _w[_c];
17
+ const F *__restrict__ const k = _k + _offset;
18
+ const F *__restrict__ const v = _v + _offset;
19
+ F *__restrict__ const y = _y + _offset;
20
+
21
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
22
+ F aa = 0, bb = 0, pp = MIN_VALUE;
23
+ for (int i = 0; i < T; i++) {
24
+ const int ii = i * C;
25
+ const F kk = k[ii];
26
+ const F vv = v[ii];
27
+
28
+ F ww = u + kk;
29
+ F p = max(pp, ww);
30
+ F e1 = exp(pp - p);
31
+ F e2 = exp(ww - p);
32
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
33
+
34
+ ww = w + pp;
35
+ p = max(ww, kk);
36
+ e1 = exp(ww - p);
37
+ e2 = exp(kk - p);
38
+ aa = e1 * aa + e2 * vv;
39
+ bb = e1 * bb + e2;
40
+ pp = p;
41
+ }
42
+ }
43
+
44
+ template <typename F>
45
+ __global__ void kernel_backward(const int B, const int T, const int C,
46
+ const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
47
+ const F *__restrict__ const _y, const F *__restrict__ const _gy,
48
+ F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
49
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
50
+ const int _b = idx / C;
51
+ const int _c = idx % C;
52
+ const int _offset = _b * T * C + _c;
53
+
54
+ F u = _u[_c];
55
+ F w = _w[_c];
56
+ const F *__restrict__ const k = _k + _offset;
57
+ const F *__restrict__ const v = _v + _offset;
58
+ const F *__restrict__ const y = _y + _offset;
59
+ const F *__restrict__ const gy = _gy + _offset;
60
+ F *__restrict__ const gk = _gk + _offset;
61
+ F *__restrict__ const gv = _gv + _offset;
62
+
63
+ F q[Tmax], r[Tmax];
64
+
65
+ F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
66
+ for (int i = 0; i < T; i++) {
67
+ const int ii = i * C;
68
+ const F kk = k[ii];
69
+ const F vv = v[ii];
70
+ const F yy = y[ii];
71
+
72
+ F ww = u + kk;
73
+ F p = max(pp, ww);
74
+ F e1 = exp(pp - p);
75
+ F e2 = exp(ww - p);
76
+ const F qq = gy[ii] / (e1 * bb + e2);
77
+ gw += (ga - gb * yy) * e1 * qq;
78
+ gu += (vv - yy) * e2 * qq;
79
+ q[i] = qq;
80
+ r[i] = ww - p;
81
+
82
+ ww = w + pp;
83
+ p = max(ww, kk);
84
+ e1 = exp(ww - p);
85
+ e2 = exp(kk - p);
86
+ ga = e1 * (aa + ga);
87
+ gb = e1 * (bb + gb);
88
+ aa = e1 * aa + e2 * vv;
89
+ bb = e1 * bb + e2;
90
+ pp = p;
91
+ }
92
+ const int _offsetBC = _b * C + _c;
93
+ _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
94
+ _gu[_offsetBC] = gu;
95
+
96
+ aa = 0, bb = 0, pp = MIN_VALUE;
97
+ for (int i = T - 1; i >= 0; i--) {
98
+ const int ii = i * C;
99
+ const F kk = k[ii];
100
+ const F vv = v[ii];
101
+ const F yy = y[ii];
102
+ const F qq = q[i];
103
+ const F rr = r[i];
104
+
105
+ F e1 = qq * exp(rr);
106
+ F e2 = exp(kk + pp);
107
+ gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
108
+ gv[ii] = e1 + e2 * aa;
109
+
110
+ const F ww = w + pp;
111
+ const F www = rr - u - kk;
112
+ const F p = max(ww, www);
113
+ e1 = exp(ww - p);
114
+ e2 = qq * exp(www - p);
115
+ aa = e1 * aa + e2;
116
+ bb = e1 * bb - e2 * yy;
117
+ pp = p;
118
+ }
119
+ }
120
+
121
+ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
122
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
123
+ assert(B * C % threadsPerBlock.x == 0);
124
+ dim3 numBlocks(B * C / threadsPerBlock.x);
125
+ kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
126
+ }
127
+
128
+ void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
129
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
130
+ assert(B * C % threadsPerBlock.x == 0);
131
+ dim3 numBlocks(B * C / threadsPerBlock.x);
132
+ kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
133
+ }
cuda/wkv_cuda_bf16.cu ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <stdio.h>
2
+ #include <assert.h>
3
+ #include "ATen/ATen.h"
4
+ #define MIN_VALUE (-1e38)
5
+ typedef at::BFloat16 bf16;
6
+
7
+ __global__ void kernel_forward(const int B, const int T, const int C,
8
+ const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
9
+ bf16 *__restrict__ const _y) {
10
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
11
+ const int _b = idx / C;
12
+ const int _c = idx % C;
13
+ const int _offset = _b * T * C + _c;
14
+
15
+ float u = float(_u[_c]);
16
+ float w = _w[_c];
17
+ const bf16 *__restrict__ const k = _k + _offset;
18
+ const bf16 *__restrict__ const v = _v + _offset;
19
+ bf16 *__restrict__ const y = _y + _offset;
20
+
21
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
22
+ float aa = 0, bb = 0, pp = MIN_VALUE;
23
+ for (int i = 0; i < T; i++) {
24
+ const int ii = i * C;
25
+ const float kk = float(k[ii]);
26
+ const float vv = float(v[ii]);
27
+
28
+ float ww = u + kk;
29
+ float p = max(pp, ww);
30
+ float e1 = exp(pp - p);
31
+ float e2 = exp(ww - p);
32
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
33
+
34
+ ww = w + pp;
35
+ p = max(ww, kk);
36
+ e1 = exp(ww - p);
37
+ e2 = exp(kk - p);
38
+ aa = e1 * aa + e2 * vv;
39
+ bb = e1 * bb + e2;
40
+ pp = p;
41
+ }
42
+ }
43
+
44
+ __global__ void kernel_backward(const int B, const int T, const int C,
45
+ const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
46
+ const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
47
+ bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
48
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
49
+ const int _b = idx / C;
50
+ const int _c = idx % C;
51
+ const int _offset = _b * T * C + _c;
52
+
53
+ float u = float(_u[_c]);
54
+ float w = _w[_c];
55
+ const bf16 *__restrict__ const k = _k + _offset;
56
+ const bf16 *__restrict__ const v = _v + _offset;
57
+ const bf16 *__restrict__ const y = _y + _offset;
58
+ const bf16 *__restrict__ const gy = _gy + _offset;
59
+ bf16 *__restrict__ const gk = _gk + _offset;
60
+ bf16 *__restrict__ const gv = _gv + _offset;
61
+
62
+ float q[Tmax], r[Tmax];
63
+
64
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
65
+ for (int i = 0; i < T; i++) {
66
+ const int ii = i * C;
67
+ const float kk = float(k[ii]);
68
+ const float vv = float(v[ii]);
69
+ const float yy = float(y[ii]);
70
+
71
+ float ww = u + kk;
72
+ float p = max(pp, ww);
73
+ float e1 = exp(pp - p);
74
+ float e2 = exp(ww - p);
75
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
76
+ gw += (ga - gb * yy) * e1 * qq;
77
+ gu += (vv - yy) * e2 * qq;
78
+ q[i] = qq;
79
+ r[i] = ww - p;
80
+
81
+ ww = w + pp;
82
+ p = max(ww, kk);
83
+ e1 = exp(ww - p);
84
+ e2 = exp(kk - p);
85
+ ga = e1 * (aa + ga);
86
+ gb = e1 * (bb + gb);
87
+ aa = e1 * aa + e2 * vv;
88
+ bb = e1 * bb + e2;
89
+ pp = p;
90
+ }
91
+ const int _offsetBC = _b * C + _c;
92
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
93
+ _gu[_offsetBC] = bf16(gu);
94
+
95
+ aa = 0, bb = 0, pp = MIN_VALUE;
96
+ for (int i = T - 1; i >= 0; i--) {
97
+ const int ii = i * C;
98
+ const float kk = float(k[ii]);
99
+ const float vv = float(v[ii]);
100
+ const float yy = float(y[ii]);
101
+ const float qq = q[i];
102
+ const float rr = r[i];
103
+
104
+ float e1 = qq * exp(rr);
105
+ float e2 = exp(kk + pp);
106
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
107
+ gv[ii] = bf16(e1 + e2 * aa);
108
+
109
+ const float ww = w + pp;
110
+ const float www = rr - u - kk;
111
+ const float p = max(ww, www);
112
+ e1 = exp(ww - p);
113
+ e2 = qq * exp(www - p);
114
+ aa = e1 * aa + e2;
115
+ bb = e1 * bb - e2 * yy;
116
+ pp = p;
117
+ }
118
+ }
119
+
120
+ void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
121
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
122
+ assert(B * C % threadsPerBlock.x == 0);
123
+ dim3 numBlocks(B * C / threadsPerBlock.x);
124
+ kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
125
+ }
126
+
127
+ void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
128
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
129
+ assert(B * C % threadsPerBlock.x == 0);
130
+ dim3 numBlocks(B * C / threadsPerBlock.x);
131
+ kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
132
+ }
cuda/wkv_op.cpp ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
4
+ void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
5
+
6
+ void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
7
+ cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
8
+ }
9
+ void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
10
+ cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
11
+ }
12
+
13
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14
+ m.def("forward", &forward, "wkv forward");
15
+ m.def("backward", &backward, "wkv backward");
16
+ }
17
+
18
+ TORCH_LIBRARY(wkv, m) {
19
+ m.def("forward", forward);
20
+ m.def("backward", backward);
21
+ }
cuda/wkv_op_bf16.cpp ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "ATen/ATen.h"
3
+ typedef at::BFloat16 bf16;
4
+
5
+ void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
6
+ void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
7
+
8
+ void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
9
+ cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
10
+ }
11
+ void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
12
+ torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
13
+ cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
14
+ gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
15
+ }
16
+
17
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
18
+ m.def("forward", &forward, "wkv forward");
19
+ m.def("backward", &backward, "wkv backward");
20
+ }
21
+
22
+ TORCH_LIBRARY(wkv, m) {
23
+ m.def("forward", forward);
24
+ m.def("backward", backward);
25
+ }
run.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import numpy as np
6
+ import math, os, sys, types, time, gc
7
+ import torch
8
+ from src.utils import TOKENIZER
9
+ try:
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
11
+ except:
12
+ pass
13
+ torch.backends.cudnn.benchmark = True
14
+ torch.backends.cudnn.allow_tf32 = True
15
+ torch.backends.cuda.matmul.allow_tf32 = True
16
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
17
+ args = types.SimpleNamespace()
18
+
19
+ ########################################################################################################
20
+ # Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible)
21
+ ########################################################################################################
22
+
23
+ args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast)
24
+ args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU)
25
+
26
+ # if args.RUN_DEVICE == "cuda":
27
+ # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output
28
+ os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!!
29
+
30
+ TOKEN_MODE = "pile"
31
+ WORD_NAME = [
32
+ "20B_tokenizer.json",
33
+ "20B_tokenizer.json",
34
+ ] # [vocab, vocab] for Pile model
35
+ UNKNOWN_CHAR = None
36
+ vocab_size = 50277
37
+
38
+ # Download Pile models: https://huggingface.co/BlinkDL
39
+ # or, set MODEL_NAME to your fine-tuned model
40
+
41
+ # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023"
42
+ # n_layer = 12
43
+ # n_embd = 768
44
+ # ctx_len = 1024
45
+
46
+ # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066'
47
+ # n_layer = 24
48
+ # n_embd = 1024
49
+ # ctx_len = 1024
50
+
51
+ # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
52
+ # n_layer = 24
53
+ # n_embd = 2048
54
+ # ctx_len = 1024
55
+
56
+ # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023'
57
+ # n_layer = 32
58
+ # n_embd = 2560
59
+ # ctx_len = 1024
60
+
61
+ MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
62
+ n_layer = 32
63
+ n_embd = 4096
64
+ ctx_len = 1024
65
+
66
+ args.MODEL_NAME = MODEL_NAME
67
+ args.n_layer = n_layer
68
+ args.n_embd = n_embd
69
+ args.ctx_len = ctx_len
70
+ args.vocab_size = vocab_size
71
+ args.head_qk = 0
72
+ args.pre_ffn = 0
73
+ args.grad_cp = 0
74
+ args.my_pos_emb = 0
75
+ os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
76
+
77
+ ########################################################################################################
78
+ # Step 2: set prompt & sampling stuffs
79
+ ########################################################################################################
80
+
81
+ # context = 'A'
82
+ # context = "\nIn the"
83
+ # context = '\nSugar:'
84
+ context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
85
+
86
+ # context = "\n深圳是" # test Chinese
87
+ # context = "\n東京は" # test Japanese
88
+
89
+ # ###### A good prompt for Q&A ######
90
+ # context = '''
91
+ # Questions & Helpful Answers
92
+ # Ask Research Experts
93
+ # Question:
94
+ # Can penguins fly?
95
+
96
+ # Full Answer:
97
+ # '''
98
+
99
+ # ###### A good prompt for chatbot ######
100
+ # context = '''
101
+ # The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.
102
+
103
+ # User: who is president of usa?
104
+
105
+ # Bot: It’s Joe Biden; he was sworn in earlier this year.
106
+
107
+ # User: french revolution what year
108
+
109
+ # Bot: It started in 1789, but it lasted 10 years until 1799.
110
+
111
+ # User: guess i marry who ?
112
+
113
+ # Bot: Only if you tell me more about yourself - what are your interests?
114
+
115
+ # User: wat is lhc
116
+
117
+ # Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
118
+
119
+ # User:''' # type your question here
120
+
121
+ NUM_TRIALS = 999
122
+ LENGTH_PER_TRIAL = 333
123
+
124
+ TEMPERATURE = 1.0
125
+ top_p = 0.8
126
+ top_p_newline = 0.9 # only used in TOKEN_MODE = char
127
+
128
+ DEBUG_DEBUG = False # True False --> show softmax output
129
+
130
+ ########################################################################################################
131
+
132
+ from src.model_run import RWKV_RNN
133
+
134
+ model = RWKV_RNN(args)
135
+
136
+ out, _ = model.forward([187], None)
137
+ # print(out)
138
+ gc.collect()
139
+ torch.cuda.empty_cache()
140
+
141
+ # input(0)
142
+
143
+ tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
144
+ if TOKEN_MODE == "pile":
145
+ assert tokenizer.tokenizer.decode([187]) == '\n'
146
+
147
+ ########################################################################################################
148
+
149
+ if tokenizer.charMode:
150
+ context = tokenizer.refine_context(context)
151
+ ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
152
+ else:
153
+ ctx = tokenizer.tokenizer.encode(context)
154
+ src_len = len(ctx)
155
+ src_ctx = ctx.copy()
156
+
157
+
158
+ time_slot = {}
159
+ time_ref = time.time_ns()
160
+
161
+ def record_time(name):
162
+ if name not in time_slot:
163
+ time_slot[name] = 1e20
164
+ tt = (time.time_ns() - time_ref) / 1e9
165
+ if tt < time_slot[name]:
166
+ time_slot[name] = tt
167
+
168
+ init_state = None
169
+ init_out = None
170
+ state = None
171
+ out = None
172
+
173
+ for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
174
+
175
+ time_ref = time.time_ns()
176
+ ctx = src_ctx.copy()
177
+
178
+ if TRIAL == 0:
179
+ for i in range(src_len):
180
+ x = ctx[: i + 1]
181
+ if i == src_len - 1:
182
+ init_out, init_state = model.forward(x, init_state)
183
+ else:
184
+ init_state = model.forward(x, init_state, preprocess_only=True)
185
+ gc.collect()
186
+ torch.cuda.empty_cache()
187
+
188
+ record_time('preprocess')
189
+ out_last = src_len
190
+ for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
191
+ x = ctx[: i + 1]
192
+ x = x[-ctx_len:]
193
+
194
+ if i == src_len:
195
+ out = init_out.clone()
196
+ state = init_state.clone()
197
+ else:
198
+ out, state = model.forward(x, state)
199
+ if DEBUG_DEBUG:
200
+ if TOKEN_MODE == "pile":
201
+ out[0] = -999999999 # disable <|endoftext|>
202
+
203
+ ttt = tokenizer.sample_logits(
204
+ out,
205
+ x,
206
+ ctx_len,
207
+ temperature=TEMPERATURE,
208
+ top_p_usual=top_p,
209
+ top_p_newline=top_p_newline,
210
+ )
211
+ ctx += [ttt]
212
+
213
+ if tokenizer.charMode:
214
+ char = tokenizer.itos[ttt]
215
+ else:
216
+ char = tokenizer.tokenizer.decode(ctx[out_last:])
217
+ if '\ufffd' not in char: # is valid utf8 string?
218
+ out_last = i+1
219
+
220
+ record_time('total')
221
+ # print(f'\n\n{time_slot}\n\n')
222
+
223
+
src/__init__.py ADDED
File without changes
src/binidx.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib2to3.pgen2 import token
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import shutil
6
+ import struct
7
+ from functools import lru_cache
8
+ from itertools import accumulate
9
+
10
+ def print_rank_0(*message):
11
+ pass
12
+ # """If distributed is initialized print only on rank 0."""
13
+ # if torch.distributed.is_initialized():
14
+ # if torch.distributed.get_rank() == 0:
15
+ # print(*message, flush=True)
16
+ # else:
17
+ # print(*message, flush=True)
18
+
19
+ def _warmup_mmap_file(path):
20
+ pass
21
+ # with open(path, "rb") as stream:
22
+ # while stream.read(100 * 1024 * 1024):
23
+ # pass
24
+
25
+ dtypes = {
26
+ 1: np.uint8,
27
+ 2: np.int8,
28
+ 3: np.int16,
29
+ 4: np.int32,
30
+ 5: np.int64,
31
+ 6: float,
32
+ 7: np.double,
33
+ 8: np.uint16,
34
+ }
35
+
36
+ def code(dtype):
37
+ for k in dtypes.keys():
38
+ if dtypes[k] == dtype:
39
+ return k
40
+ raise ValueError(dtype)
41
+
42
+ def index_file_path(prefix_path):
43
+ return prefix_path + ".idx"
44
+
45
+ def data_file_path(prefix_path):
46
+ return prefix_path + ".bin"
47
+
48
+ class MMapIndexedDataset(torch.utils.data.Dataset):
49
+ class Index(object):
50
+ _HDR_MAGIC = b"MMIDIDX\x00\x00"
51
+
52
+ @classmethod
53
+ def writer(cls, path, dtype):
54
+ class _Writer(object):
55
+ def __enter__(self):
56
+ self._file = open(path, "wb")
57
+
58
+ # Write Magic string so we can check the file format then opening it again.
59
+ self._file.write(cls._HDR_MAGIC)
60
+ # Write version number
61
+ # Little endian unsigned 64 Bit integer
62
+ self._file.write(struct.pack("<Q", 1))
63
+ # Little endian unsigned 8 Bit integer
64
+ self._file.write(struct.pack("<B", code(dtype)))
65
+
66
+ return self
67
+
68
+ @staticmethod
69
+ def _get_pointers(sizes):
70
+ dtype_size = dtype().itemsize
71
+ address = 0
72
+ pointers = []
73
+
74
+ for size in sizes:
75
+ pointers.append(address)
76
+ address += size * dtype_size
77
+
78
+ return pointers
79
+
80
+ def write(self, sizes, doc_idx):
81
+ pointers = self._get_pointers(sizes)
82
+
83
+ # Little endian unsigned 64 Bit integer
84
+ self._file.write(struct.pack("<Q", len(sizes)))
85
+ # Little endian unsigned 64 Bit integer
86
+ self._file.write(struct.pack("<Q", len(doc_idx)))
87
+
88
+ sizes = np.array(sizes, dtype=np.int32)
89
+ self._file.write(sizes.tobytes(order="C"))
90
+ del sizes
91
+
92
+ pointers = np.array(pointers, dtype=np.int64)
93
+ self._file.write(pointers.tobytes(order="C"))
94
+ del pointers
95
+
96
+ doc_idx = np.array(doc_idx, dtype=np.int64)
97
+ self._file.write(doc_idx.tobytes(order="C"))
98
+
99
+ def __exit__(self, exc_type, exc_val, exc_tb):
100
+ self._file.close()
101
+
102
+ return _Writer()
103
+
104
+ def __init__(self, path, skip_warmup=False):
105
+ with open(path, "rb") as stream:
106
+ magic_test = stream.read(9)
107
+ assert self._HDR_MAGIC == magic_test, (
108
+ "Index file doesn't match expected format. "
109
+ "Make sure that --dataset-impl is configured properly."
110
+ )
111
+ # Little endian unsigned 64 Bit integer
112
+ version = struct.unpack("<Q", stream.read(8))
113
+ assert (1,) == version
114
+
115
+ # Little endian unsigned 8 Bit integer
116
+ (dtype_code,) = struct.unpack("<B", stream.read(1))
117
+ self._dtype = dtypes[dtype_code]
118
+ self._dtype_size = self._dtype().itemsize
119
+
120
+ self._len = struct.unpack("<Q", stream.read(8))[0]
121
+ self._doc_count = struct.unpack("<Q", stream.read(8))[0]
122
+ offset = stream.tell()
123
+
124
+ if not skip_warmup:
125
+ print_rank_0(" warming up index mmap file...")
126
+ _warmup_mmap_file(path)
127
+
128
+ self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
129
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
130
+ print_rank_0(" reading sizes...")
131
+ self._sizes = np.frombuffer(
132
+ self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
133
+ )
134
+ print_rank_0(" reading pointers...")
135
+ self._pointers = np.frombuffer(
136
+ self._bin_buffer,
137
+ dtype=np.int64,
138
+ count=self._len,
139
+ offset=offset + self._sizes.nbytes,
140
+ )
141
+ print_rank_0(" reading document index...")
142
+ self._doc_idx = np.frombuffer(
143
+ self._bin_buffer,
144
+ dtype=np.int64,
145
+ count=self._doc_count,
146
+ offset=offset + self._sizes.nbytes + self._pointers.nbytes,
147
+ )
148
+
149
+ def __del__(self):
150
+ self._bin_buffer_mmap._mmap.close()
151
+ del self._bin_buffer_mmap
152
+
153
+ @property
154
+ def dtype(self):
155
+ return self._dtype
156
+
157
+ @property
158
+ def sizes(self):
159
+ return self._sizes
160
+
161
+ @property
162
+ def doc_idx(self):
163
+ return self._doc_idx
164
+
165
+ @lru_cache(maxsize=8)
166
+ def __getitem__(self, i):
167
+ return self._pointers[i], self._sizes[i]
168
+
169
+ def __len__(self):
170
+ return self._len
171
+
172
+ def __init__(self, path, skip_warmup=False):
173
+ super().__init__()
174
+
175
+ self._path = None
176
+ self._index = None
177
+ self._bin_buffer = None
178
+
179
+ self._do_init(path, skip_warmup)
180
+
181
+ def __getstate__(self):
182
+ return self._path
183
+
184
+ def __setstate__(self, state):
185
+ self._do_init(state)
186
+
187
+ def _do_init(self, path, skip_warmup):
188
+ self._path = path
189
+ self._index = self.Index(index_file_path(self._path), skip_warmup)
190
+
191
+ if not skip_warmup:
192
+ print_rank_0(" warming up data mmap file...")
193
+ _warmup_mmap_file(data_file_path(self._path))
194
+ print_rank_0(" creating numpy buffer of mmap...")
195
+ self._bin_buffer_mmap = np.memmap(
196
+ data_file_path(self._path), mode="r", order="C"
197
+ )
198
+ print_rank_0(" creating memory view of numpy buffer...")
199
+ self._bin_buffer = memoryview(self._bin_buffer_mmap)
200
+
201
+ def __del__(self):
202
+ self._bin_buffer_mmap._mmap.close()
203
+ del self._bin_buffer_mmap
204
+ del self._index
205
+
206
+ def __len__(self):
207
+ return len(self._index)
208
+
209
+ # @lru_cache(maxsize=8)
210
+ def __getitem__(self, idx):
211
+ if isinstance(idx, int):
212
+ ptr, size = self._index[idx]
213
+ np_array = np.frombuffer(
214
+ self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
215
+ )
216
+ return np_array
217
+ elif isinstance(idx, slice):
218
+ start, stop, step = idx.indices(len(self))
219
+ if step != 1:
220
+ raise ValueError(
221
+ "Slices into indexed_dataset must be contiguous")
222
+ ptr = self._index._pointers[start]
223
+ sizes = self._index._sizes[idx]
224
+ offsets = list(accumulate(sizes))
225
+ total_size = sum(sizes)
226
+ np_array = np.frombuffer(
227
+ self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
228
+ )
229
+ sents = np.split(np_array, offsets[:-1])
230
+ return sents
231
+
232
+ def get(self, idx, offset=0, length=None):
233
+ """Retrieves a single item from the dataset with the option to only
234
+ return a portion of the item.
235
+
236
+ get(idx) is the same as [idx] but get() does not support slicing.
237
+ """
238
+ ptr, size = self._index[idx]
239
+ if length is None:
240
+ length = size - offset
241
+ ptr += offset * np.dtype(self._index.dtype).itemsize
242
+ np_array = np.frombuffer(
243
+ self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
244
+ )
245
+ return np_array
246
+
247
+ @property
248
+ def sizes(self):
249
+ return self._index.sizes
250
+
251
+ @property
252
+ def doc_idx(self):
253
+ return self._index.doc_idx
254
+
255
+ def get_doc_idx(self):
256
+ return self._index._doc_idx
257
+
258
+ def set_doc_idx(self, doc_idx_):
259
+ self._index._doc_idx = doc_idx_
260
+
261
+ @property
262
+ def supports_prefetch(self):
263
+ return False
264
+
265
+ @staticmethod
266
+ def exists(path):
267
+ return os.path.exists(index_file_path(path)) and os.path.exists(
268
+ data_file_path(path)
269
+ )
src/dataset.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import json, math, random, os, sys
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from pytorch_lightning.utilities import rank_zero_info
10
+ from .binidx import MMapIndexedDataset
11
+ from .utils import MaybeIsPrime
12
+
13
+
14
+ class MyDataset(Dataset):
15
+ def __init__(self, args):
16
+ self.args = args
17
+
18
+ if args.data_type == "binidx":
19
+ self.vocab_size = args.vocab_size
20
+ rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
21
+
22
+ if args.my_pile_version == 1:
23
+ self.data = MMapIndexedDataset(args.data_file)
24
+ self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
25
+ rank_zero_info(f"Data has {self.data_size} tokens.")
26
+ else:
27
+ data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
28
+ data_list = [i.strip().split(' ') for i in data_list]
29
+ self.data = []
30
+ self.data_size = int(data_list[-1][-1])
31
+ rank_zero_info(f"Data has {self.data_size} chunks.")
32
+ for d in data_list:
33
+ data = MMapIndexedDataset(d[0])
34
+ data_size = len(data._bin_buffer) // data._index._dtype_size
35
+ assert (data_size - args.ctx_len) == int(d[1])
36
+ self.data += [[int(d[-1]), int(d[1]), data]]
37
+ # rank_zero_info(self.data)
38
+
39
+ if args.my_qa_mask > 0:
40
+ # self.data_pile = MMapIndexedDataset('/fsx/pile/pile_20B_tokenizer_text_document')
41
+ self.data_pile = MMapIndexedDataset('/fsx/pile_deduped/pile_0.87_deduped_text_document')
42
+ self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
43
+ else:
44
+ self.data_pile = None
45
+ self.data_pile_size = 0
46
+
47
+ if args.my_pile_stage > 0:
48
+ # assert self.data_size == 332115325534 and self.vocab_size == 50277
49
+ self.samples_per_epoch = args.epoch_steps * args.real_bsz
50
+ assert self.samples_per_epoch == 40320
51
+ rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
52
+ dataset_slot = self.data_size // args.ctx_len
53
+ if args.my_pile_stage != 4:
54
+ assert MaybeIsPrime(args.magic_prime)
55
+ assert args.magic_prime % 3 == 2
56
+ assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
57
+ elif args.data_type == "numpy":
58
+ self.data = np.load(args.data_file).astype("int")
59
+ self.vocab_size = args.vocab_size
60
+ rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
61
+ self.data_size = len(self.data)
62
+ rank_zero_info(f"Data has {self.data_size} tokens.")
63
+ elif args.data_type == "uint16":
64
+ self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
65
+ self.vocab_size = args.vocab_size
66
+ rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
67
+ self.data_size = self.data.shape[0]
68
+ rank_zero_info(f"Data has {self.data_size} samples.")
69
+ elif args.data_type == "wds_img":
70
+ self.vocab_size = -1
71
+ self.data_size = -1
72
+ self.data = None
73
+ self.error_count = 0
74
+ else:
75
+ if args.data_type == "dummy":
76
+ rank_zero_info("Building dummy data...")
77
+ self.data = ""
78
+ for i in range(100000):
79
+ aa = (i) % 10000
80
+ bb = (i * i) % 10000
81
+ cc = aa + bb
82
+ self.data += f".{aa}+{bb}={cc}."
83
+ else:
84
+ self.data = open(args.data_file, "r", encoding=args.data_type).read()
85
+ rank_zero_info("Building token list...")
86
+ unique = sorted(list(set(self.data)))
87
+ self.vocab_size = len(unique)
88
+ # rank_zero_info()
89
+ # for u in unique:
90
+ # print(u, end=' ')
91
+ # rank_zero_info('\n\n')
92
+ xx = 0
93
+ xxObj = {}
94
+ for u in unique:
95
+ xxObj[xx] = u
96
+ xx += 1
97
+ with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
98
+ vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
99
+ self.data_size = len(self.data)
100
+ rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
101
+ self.stoi = {ch: i for i, ch in enumerate(unique)}
102
+ self.itos = {i: ch for i, ch in enumerate(unique)}
103
+
104
+ def __len__(self):
105
+ return self.args.epoch_steps * self.args.micro_bsz
106
+
107
+ def __getitem__(self, idx):
108
+ args = self.args
109
+ rank = self.global_rank
110
+ epoch = self.real_epoch
111
+ world_size = self.world_size
112
+ # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
113
+
114
+ if args.data_type == "wds_img":
115
+ def init_wds(self, bias=0):
116
+ def identity(x):
117
+ return x
118
+ import webdataset as wds
119
+ import torchvision.transforms as transforms
120
+ # img_transform = transforms.Compose(
121
+ # [transforms.CenterCrop(256)]
122
+ # )
123
+ img_transform = transforms.Compose([
124
+ transforms.CenterCrop(512),
125
+ transforms.Resize((args.my_img_size))
126
+ ])
127
+ self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
128
+ for pp in self.data_raw.pipeline:
129
+ if 'Resampled' in str(pp):
130
+ pp.deterministic = True
131
+ def worker_seed():
132
+ return rank*100000+epoch+bias*1e9
133
+ pp.worker_seed = worker_seed
134
+ self.data = iter(self.data_raw)
135
+ # print(f"WebDataset loaded for rank {rank} epoch {epoch}")
136
+ if self.data == None:
137
+ init_wds(self)
138
+ trial = 0
139
+ while trial < 10:
140
+ try:
141
+ dd = next(self.data) # jpg, json, txt
142
+ break
143
+ except:
144
+ print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
145
+ self.error_count += 1
146
+ init_wds(self, self.error_count)
147
+ trial += 1
148
+ pass
149
+ # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
150
+ # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
151
+ # tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
152
+ return dd[0], dd[2]
153
+ else:
154
+ if args.data_type == "uint16":
155
+ i = np.random.randint(0, self.data_size-1)
156
+ dix = self.data[i]
157
+ x = torch.tensor(dix[:-1], dtype=torch.long)
158
+ y = torch.tensor(dix[1:], dtype=torch.long)
159
+ else:
160
+ ctx_len = args.ctx_len
161
+ req_len = ctx_len + 1
162
+ magic_prime = args.magic_prime
163
+ data = self.data
164
+
165
+ if args.my_pile_stage > 0:
166
+ ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
167
+
168
+ if args.my_qa_mask > 0:
169
+ ii_orig = ii
170
+ if ii % 2 == 0:
171
+ ii = -1
172
+ data = self.data_pile
173
+ else:
174
+ ii = ii // 2
175
+ if data == self.data_pile:
176
+ i = np.random.randint(0, self.data_pile_size - req_len)
177
+ else:
178
+ if args.my_pile_stage == 4 or ii < args.my_random_steps:
179
+ # cheat: pick a random spot in dataset
180
+ if args.my_pile_version == 1:
181
+ i = np.random.randint(0, self.data_size - req_len)
182
+ else:
183
+ i = np.random.randint(0, self.data_size)
184
+ else:
185
+ ii = ii - args.my_random_steps
186
+ factor = (math.sqrt(5) - 1) / 2
187
+ factor = int(magic_prime * factor)
188
+ i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
189
+ i = i + args.my_pile_shift
190
+ # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
191
+ else:
192
+ # cheat: pick a random spot in dataset
193
+ i = np.random.randint(0, self.data_size - req_len)
194
+
195
+ if args.data_type == "binidx":
196
+ if args.my_pile_version == 1:
197
+ dix = data.get(idx=0, offset=i, length=req_len).astype(int)
198
+ else:
199
+ # self.data : cutoff, chunk_count, data
200
+ for j in range(len(data)):
201
+ if i < data[j][0]:
202
+ ii = i
203
+ i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
204
+ dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
205
+ # print(ii, j, i)
206
+ break
207
+ elif args.data_type == "numpy":
208
+ dix = data[i : i + req_len]
209
+ else:
210
+ dix = [self.stoi[s] for s in data[i : i + req_len]]
211
+
212
+ if args.my_qa_mask == 1:
213
+ if data == self.data_pile:
214
+ z = [1] * ctx_len
215
+ else:
216
+ z = [0] * ctx_len
217
+ z_sum = 0
218
+ isGood = False
219
+ for i in range(3, ctx_len):
220
+ if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
221
+ isGood = True
222
+ if dix[i] == 0:
223
+ isGood = False
224
+ if isGood:
225
+ z[i] = 1
226
+ z_sum += 1
227
+ if z_sum == 0:
228
+ z = [1] * ctx_len
229
+ i = np.random.randint(0, self.data_pile_size - req_len)
230
+ dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
231
+ z = torch.tensor(z, dtype=torch.bfloat16)
232
+
233
+ x = torch.tensor(dix[:-1], dtype=torch.long)
234
+ y = torch.tensor(dix[1:], dtype=torch.long)
235
+
236
+ # if ii_orig < 50:
237
+ # # if rank == 1:
238
+ # print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
239
+ # else:
240
+ # exit(0)
241
+
242
+ if args.my_qa_mask == 1:
243
+ return x, y, z
244
+
245
+ return x, y
src/model.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import os, math, gc, importlib
6
+ import torch
7
+ # torch._C._jit_set_profiling_executor(True)
8
+ # torch._C._jit_set_profiling_mode(True)
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ import pytorch_lightning as pl
12
+ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
13
+ from pytorch_lightning.strategies import DeepSpeedStrategy
14
+ if importlib.util.find_spec('deepspeed'):
15
+ import deepspeed
16
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
17
+
18
+ # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
19
+
20
+ try:
21
+ print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
22
+ except:
23
+ os.environ["RWKV_MY_TESTING"] = ''
24
+
25
+ def __nop(ob):
26
+ return ob
27
+
28
+
29
+ MyModule = nn.Module
30
+ MyFunction = __nop
31
+ if os.environ["RWKV_JIT_ON"] == "1":
32
+ MyModule = torch.jit.ScriptModule
33
+ MyFunction = torch.jit.script_method
34
+
35
+
36
+ ########################################################################################################
37
+ # CUDA Kernel
38
+ ########################################################################################################
39
+
40
+ T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
41
+ # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
42
+
43
+ from torch.utils.cpp_extension import load
44
+
45
+ if os.environ["RWKV_FLOAT_MODE"] == "bf16":
46
+ wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
47
+ class WKV(torch.autograd.Function):
48
+ @staticmethod
49
+ def forward(ctx, B, T, C, w, u, k, v):
50
+ ctx.B = B
51
+ ctx.T = T
52
+ ctx.C = C
53
+ assert T <= T_MAX
54
+ assert B * C % min(C, 32) == 0
55
+ w = -torch.exp(w.float().contiguous())
56
+ u = u.contiguous()
57
+ k = k.contiguous()
58
+ v = v.contiguous()
59
+ y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
60
+ wkv_cuda.forward(B, T, C, w, u, k, v, y)
61
+ ctx.save_for_backward(w, u, k, v, y)
62
+ return y
63
+ @staticmethod
64
+ def backward(ctx, gy):
65
+ B = ctx.B
66
+ T = ctx.T
67
+ C = ctx.C
68
+ assert T <= T_MAX
69
+ assert B * C % min(C, 32) == 0
70
+ w, u, k, v, y = ctx.saved_tensors
71
+ gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
72
+ gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
73
+ gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
74
+ gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
75
+ wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
76
+ gw = torch.sum(gw, dim=0)
77
+ gu = torch.sum(gu, dim=0)
78
+ return (None, None, None, gw, gu, gk, gv)
79
+ else:
80
+ wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
81
+ class WKV(torch.autograd.Function):
82
+ @staticmethod
83
+ def forward(ctx, B, T, C, w, u, k, v):
84
+ ctx.B = B
85
+ ctx.T = T
86
+ ctx.C = C
87
+ assert T <= T_MAX
88
+ assert B * C % min(C, 32) == 0
89
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
90
+ w = -torch.exp(w.contiguous())
91
+ u = u.contiguous()
92
+ k = k.contiguous()
93
+ v = v.contiguous()
94
+ else:
95
+ w = -torch.exp(w.float().contiguous())
96
+ u = u.float().contiguous()
97
+ k = k.float().contiguous()
98
+ v = v.float().contiguous()
99
+ y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
100
+ wkv_cuda.forward(B, T, C, w, u, k, v, y)
101
+ ctx.save_for_backward(w, u, k, v, y)
102
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
103
+ return y
104
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
105
+ return y.half()
106
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
107
+ return y.bfloat16()
108
+ @staticmethod
109
+ def backward(ctx, gy):
110
+ B = ctx.B
111
+ T = ctx.T
112
+ C = ctx.C
113
+ assert T <= T_MAX
114
+ assert B * C % min(C, 32) == 0
115
+ w, u, k, v, y = ctx.saved_tensors
116
+ gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
117
+ gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
118
+ gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
119
+ gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
120
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
121
+ wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
122
+ else:
123
+ wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
124
+ gw = torch.sum(gw, dim=0)
125
+ gu = torch.sum(gu, dim=0)
126
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
127
+ return (None, None, None, gw, gu, gk, gv)
128
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
129
+ return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
130
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
131
+ return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
132
+
133
+
134
+ def RUN_CUDA(B, T, C, w, u, k, v):
135
+ return WKV.apply(B, T, C, w, u, k, v)
136
+
137
+
138
+ ########################################################################################################
139
+ # RWKV: RWKV Time-mix + RWKV Channel-mix
140
+ ########################################################################################################
141
+
142
+
143
+ class RWKV_TimeMix(MyModule):
144
+ def __init__(self, args, layer_id):
145
+ super().__init__()
146
+ self.args = args
147
+ self.layer_id = layer_id
148
+ self.ctx_len = args.ctx_len
149
+ self.n_embd = args.n_embd
150
+
151
+ with torch.no_grad(): # fancy init
152
+ ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
153
+ ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
154
+ ddd = torch.ones(1, 1, args.n_embd)
155
+ for i in range(args.n_embd):
156
+ ddd[0, 0, i] = i / args.n_embd
157
+
158
+ # fancy time_decay
159
+ decay_speed = torch.ones(args.dim_att)
160
+ for h in range(args.dim_att):
161
+ decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
162
+ self.time_decay = nn.Parameter(decay_speed)
163
+ # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
164
+
165
+ # fancy time_first
166
+ zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
167
+ self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
168
+
169
+ # fancy time_mix
170
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
171
+ self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
172
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
173
+
174
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
175
+ self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
176
+ self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
177
+ self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
178
+ self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
179
+
180
+ if 'a' in os.environ["RWKV_MY_TESTING"]:
181
+ self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
182
+ d_qkv = args.n_embd // 16
183
+ self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
184
+ self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
185
+ self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
186
+ self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
187
+ with torch.no_grad():
188
+ self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
189
+ self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
190
+ self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
191
+
192
+ if 'a' not in os.environ["RWKV_MY_TESTING"]:
193
+ @MyFunction
194
+ def jit_func(self, x):
195
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
196
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
197
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
198
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
199
+ k = self.key(xk)
200
+ v = self.value(xv)
201
+ r = self.receptance(xr)
202
+ sr = torch.sigmoid(r)
203
+ return sr, k, v
204
+
205
+ def forward(self, x):
206
+ B, T, C = x.size() # x = (Batch,Time,Channel)
207
+ sr, k, v = self.jit_func(x)
208
+ rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
209
+ return self.output(rwkv)
210
+
211
+ if 'a' in os.environ["RWKV_MY_TESTING"]:
212
+ @MyFunction
213
+ def QKV(self, q, k, v):
214
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
215
+ att = att.masked_fill(self.att_mask == 0, float('-inf'))
216
+ att = F.softmax(att, dim = -1)
217
+ x = att @ v
218
+ return x
219
+
220
+ @MyFunction
221
+ def jit_funcQKV(self, x):
222
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
223
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
224
+ xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
225
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
226
+ xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
227
+ xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
228
+ xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
229
+ k = self.key(xk)
230
+ v = self.value(xv)
231
+ r = self.receptance(xr)
232
+ sr = torch.sigmoid(r)
233
+ qq = self.qq(xqq)
234
+ kk = self.kk(xkk)
235
+ vv = self.vv(xvv)
236
+ return sr, k, v, qq, kk, vv
237
+
238
+ def forward(self, x):
239
+ B, T, C = x.size() # x = (Batch,Time,Channel)
240
+ sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
241
+ rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
242
+ rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
243
+ return rwkv
244
+
245
+ ########################################################################################################
246
+
247
+ class RWKV_ChannelMix(MyModule):
248
+ def __init__(self, args, layer_id):
249
+ super().__init__()
250
+ self.args = args
251
+ self.layer_id = layer_id
252
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
253
+
254
+ with torch.no_grad(): # fancy init of time_mix
255
+ ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
256
+ ddd = torch.ones(1, 1, args.n_embd)
257
+ for i in range(args.n_embd):
258
+ ddd[0, 0, i] = i / args.n_embd
259
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
260
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
261
+
262
+ self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
263
+ self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
264
+ self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
265
+
266
+ @MyFunction
267
+ def forward(self, x):
268
+ xx = self.time_shift(x)
269
+ xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
270
+ xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
271
+ k = self.key(xk)
272
+ k = torch.square(torch.relu(k))
273
+ kv = self.value(k)
274
+ return torch.sigmoid(self.receptance(xr)) * kv
275
+
276
+ class MishGLU(MyModule):
277
+ def __init__(self, args, layer_id):
278
+ super().__init__()
279
+ self.args = args
280
+ self.layer_id = layer_id
281
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
282
+
283
+ with torch.no_grad():
284
+ ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
285
+
286
+ x = torch.ones(1, 1, args.n_embd)
287
+ for i in range(args.n_embd):
288
+ x[0, 0, i] = i / args.n_embd
289
+
290
+ self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
291
+ self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
292
+ self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
293
+ self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
294
+ self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
295
+
296
+ @MyFunction
297
+ def forward(self, x):
298
+ xx = self.time_shift(x)
299
+ xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
300
+ xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
301
+ a = self.aa(xa)
302
+ b = self.bb(xb)
303
+ return self.value(a * F.mish(b))
304
+
305
+ ########################################################################################################
306
+ # The RWKV Model with our blocks
307
+ ########################################################################################################
308
+
309
+
310
+ class Block(nn.Module):
311
+ def __init__(self, args, layer_id):
312
+ super().__init__()
313
+ self.args = args
314
+ self.layer_id = layer_id
315
+
316
+ self.ln1 = nn.LayerNorm(args.n_embd)
317
+ self.ln2 = nn.LayerNorm(args.n_embd)
318
+
319
+ if self.layer_id == 0:
320
+ self.ln0 = nn.LayerNorm(args.n_embd)
321
+ if args.my_pos_emb > 0:
322
+ self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
323
+ self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
324
+
325
+ if self.layer_id == 0 and self.args.pre_ffn > 0:
326
+ self.ffnPre = RWKV_ChannelMix(args, 0)
327
+ else:
328
+ self.att = RWKV_TimeMix(args, layer_id)
329
+
330
+ if 'g' in os.environ["RWKV_MY_TESTING"]:
331
+ self.ffn = MishGLU(args, layer_id)
332
+ else:
333
+ self.ffn = RWKV_ChannelMix(args, layer_id)
334
+
335
+ if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
336
+ self.tiny_ln = nn.LayerNorm(args.n_embd)
337
+ self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
338
+ self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
339
+ self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
340
+ self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
341
+
342
+ def forward(self, x, x_emb=None):
343
+ args = self.args
344
+ B, T, C = x.size()
345
+ if self.layer_id == 0:
346
+ x = self.ln0(x)
347
+ if args.my_pos_emb > 0:
348
+ pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
349
+ x = x + pos_emb
350
+
351
+ if self.layer_id == 0 and args.pre_ffn > 0:
352
+ x = x + self.ffnPre(self.ln1(x))
353
+ else:
354
+ x = x + self.att(self.ln1(x))
355
+ x = x + self.ffn(self.ln2(x))
356
+
357
+ if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
358
+ xx = self.tiny_ln(x)
359
+ q = self.tiny_q(xx)[:, :T, :]
360
+ k = self.tiny_k(xx)[:, :T, :]
361
+ c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
362
+ c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
363
+ x = x + c @ self.tiny_v(x_emb)
364
+ return x
365
+
366
+
367
+ class L2Wrap(torch.autograd.Function):
368
+ @staticmethod
369
+ def forward(ctx, loss, y):
370
+ ctx.save_for_backward(y)
371
+ return loss
372
+
373
+ @staticmethod
374
+ def backward(ctx, grad_output):
375
+ y = ctx.saved_tensors[0]
376
+ # to encourage the logits to be close to 0
377
+ factor = 1e-4 / (y.shape[0] * y.shape[1])
378
+ maxx, ids = torch.max(y, -1, keepdim=True)
379
+ gy = torch.zeros_like(y)
380
+ gy.scatter_(-1, ids, maxx * factor)
381
+ return (grad_output, gy)
382
+
383
+
384
+ class RWKV(pl.LightningModule):
385
+ def __init__(self, args):
386
+ super().__init__()
387
+ self.args = args
388
+ if not hasattr(args, 'dim_att'):
389
+ args.dim_att = args.n_embd
390
+ if not hasattr(args, 'dim_ffn'):
391
+ args.dim_ffn = args.n_embd * 4
392
+ if not hasattr(args, 'tiny_att_layer'):
393
+ args.tiny_att_layer = -1
394
+ if not hasattr(args, 'tiny_att_dim'):
395
+ args.tiny_att_dim = -1
396
+
397
+ self.emb = nn.Embedding(args.vocab_size, args.n_embd)
398
+
399
+ self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
400
+
401
+ self.ln_out = nn.LayerNorm(args.n_embd)
402
+ self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
403
+
404
+ if args.head_qk > 0:
405
+ self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
406
+ self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
407
+ self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
408
+
409
+ def configure_optimizers(self):
410
+ args = self.args
411
+ if args.layerwise_lr > 0:
412
+ lr_1x = set()
413
+ lr_2x = set()
414
+ lr_3x = set()
415
+ for n, p in self.named_parameters():
416
+ if "time_mix" in n:
417
+ if args.my_pile_stage == 2:
418
+ lr_2x.add(n)
419
+ else:
420
+ lr_1x.add(n)
421
+ elif "time_decay" in n:
422
+ if args.my_pile_stage == 2:
423
+ lr_3x.add(n)
424
+ else:
425
+ lr_2x.add(n)
426
+ elif "time_first" in n:
427
+ lr_3x.add(n)
428
+ else:
429
+ lr_1x.add(n)
430
+ lr_1x = sorted(list(lr_1x))
431
+ lr_2x = sorted(list(lr_2x))
432
+ lr_3x = sorted(list(lr_3x))
433
+ # print('1x', lr_1x)
434
+ # print('2x', lr_2x)
435
+ # print('3x', lr_3x)
436
+ param_dict = {n: p for n, p in self.named_parameters()}
437
+ if args.my_pile_stage == 2:
438
+ optim_groups = [
439
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
440
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
441
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
442
+ ]
443
+ else:
444
+ optim_groups = [
445
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
446
+ {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
447
+ {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
448
+ ]
449
+ else:
450
+ optim_groups = [
451
+ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
452
+ ]
453
+
454
+ if self.deepspeed_offload:
455
+ return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
456
+ return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
457
+ # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
458
+
459
+ @property
460
+ def deepspeed_offload(self) -> bool:
461
+ strategy = self.trainer.strategy
462
+ if isinstance(strategy, DeepSpeedStrategy):
463
+ cfg = strategy.config["zero_optimization"]
464
+ return cfg.get("offload_optimizer") or cfg.get("offload_param")
465
+ return False
466
+
467
+ def forward(self, idx):
468
+ args = self.args
469
+ B, T = idx.size()
470
+ assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
471
+
472
+ x = self.emb(idx)
473
+ x_emb = x
474
+
475
+ if args.tiny_att_dim > 0:
476
+ for block in self.blocks:
477
+ if args.grad_cp == 1:
478
+ x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
479
+ else:
480
+ x = block(x, x_emb)
481
+ else:
482
+ for block in self.blocks:
483
+ if args.grad_cp == 1:
484
+ x = deepspeed.checkpointing.checkpoint(block, x)
485
+ else:
486
+ x = block(x)
487
+
488
+ x = self.ln_out(x)
489
+
490
+ if args.head_qk > 0:
491
+ q = self.head_q(x)[:, :T, :]
492
+ k = self.head_k(x)[:, :T, :]
493
+ c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
494
+ c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
495
+
496
+ if "32" in os.environ["RWKV_FLOAT_MODE"]:
497
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size)
498
+ elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
499
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
500
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
501
+ c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
502
+
503
+ x = self.head(x) + c
504
+ else:
505
+ x = self.head(x)
506
+
507
+ return x
508
+
509
+ def training_step(self, batch, batch_idx):
510
+ args = self.args
511
+ if args.my_qa_mask != 1:
512
+ idx, targets = batch
513
+ logits = self(idx)
514
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
515
+ else:
516
+ idx, targets, mask = batch
517
+ mask = mask.view(-1)
518
+ sum_mask = torch.sum(mask).item()
519
+ # if sum_mask == 0:
520
+ # return torch.tensor([0.0], requires_grad=True)
521
+
522
+ logits = self(idx)
523
+ if sum_mask == mask.shape[0]:
524
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
525
+ # print('rank', self.global_rank, 'loss', loss.item())
526
+ else:
527
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
528
+ # loss_raw = loss
529
+ loss = torch.sum(loss * mask) / sum_mask
530
+
531
+ # torch.set_printoptions(threshold=10000)
532
+ # if True: #self.global_rank == 1:
533
+ # tmp = ''
534
+ # sss = 0
535
+ # ccc = 0
536
+ # for i in range(mask.shape[0]):
537
+ # if mask[i] > 0:
538
+ # tmp += str(idx.view(-1)[i].item()) + ','
539
+ # sss += loss_raw.view(-1)[i].float().item()
540
+ # ccc += 1
541
+ # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
542
+
543
+ return L2Wrap.apply(loss, logits)
544
+
545
+ def training_step_end(self, batch_parts):
546
+ all = self.all_gather(batch_parts)
547
+ if self.trainer.is_global_zero:
548
+ self.trainer.my_loss_all = all
549
+
550
+ def generate_init_weight(self):
551
+ print(
552
+ f"""
553
+ ############################################################################
554
+ #
555
+ # Init model weight (slow for large models)...
556
+ #
557
+ ############################################################################
558
+ """
559
+ )
560
+ m = {}
561
+ for n in self.state_dict():
562
+ p = self.state_dict()[n]
563
+ shape = p.shape
564
+
565
+ gain = 1.0
566
+ scale = 1.0
567
+ if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
568
+ m[n] = p
569
+ else:
570
+ if n == "emb.weight":
571
+ scale = -1 * self.args.lr_init
572
+ else:
573
+ if shape[0] > shape[1]:
574
+ gain = math.sqrt(shape[0] / shape[1])
575
+ for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
576
+ if kk in n:
577
+ scale = 0
578
+ if n == "head.weight":
579
+ scale = 0.5
580
+ if "head_k." in n:
581
+ scale = 0.1
582
+ if "head_q." in n:
583
+ scale = 0
584
+
585
+ print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
586
+
587
+ if self.args.accelerator.upper() == "GPU":
588
+ m[n] = torch.empty((shape[0], shape[1]), device="cuda")
589
+ else:
590
+ m[n] = torch.empty((shape[0], shape[1]))
591
+
592
+ if scale == 0:
593
+ nn.init.zeros_(m[n])
594
+ elif scale < 0:
595
+ nn.init.uniform_(m[n], a=scale, b=-scale)
596
+ else:
597
+ nn.init.orthogonal_(m[n], gain=gain * scale)
598
+
599
+ m[n] = m[n].cpu()
600
+ if os.environ["RWKV_FLOAT_MODE"] == "fp16":
601
+ m[n] = m[n].half()
602
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
603
+ m[n] = m[n].bfloat16()
604
+
605
+ # if n == "emb.weight":
606
+ # print(m[n])
607
+
608
+ gc.collect()
609
+ torch.cuda.empty_cache()
610
+ return m
src/model_img.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import numpy as np
6
+ import os, math, gc
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision as vision
11
+ import pytorch_lightning as pl
12
+ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
13
+ from pytorch_lightning.strategies import DeepSpeedStrategy
14
+ import deepspeed
15
+ from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
16
+ # from pytorch_msssim import MS_SSIM
17
+
18
+ def __nop(ob):
19
+ return ob
20
+ MyModule = torch.jit.ScriptModule
21
+ # MyFunction = __nop
22
+ MyFunction = torch.jit.script_method
23
+
24
+ import clip
25
+ from transformers import CLIPModel
26
+
27
+ class L2pooling(nn.Module):
28
+ def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
29
+ super(L2pooling, self).__init__()
30
+ self.padding = (filter_size - 2) // 2
31
+ self.stride = stride
32
+ self.channels = channels
33
+ a = np.hanning(filter_size)[1:-1]
34
+ g = torch.Tensor(a[:, None] * a[None, :])
35
+ g = g / torch.sum(g)
36
+ self.register_buffer(
37
+ "filter", g[None, None, :, :].repeat((self.channels, 1, 1, 1))
38
+ )
39
+
40
+ def forward(self, input):
41
+ input = input**2
42
+ out = F.conv2d(
43
+ input,
44
+ self.filter,
45
+ stride=self.stride,
46
+ padding=self.padding,
47
+ groups=input.shape[1],
48
+ )
49
+ return (out + 1e-12).sqrt()
50
+
51
+
52
+ class DISTS(torch.nn.Module):
53
+ def __init__(self, load_weights=True):
54
+ super(DISTS, self).__init__()
55
+ vgg_pretrained_features = vision.models.vgg16(
56
+ weights="VGG16_Weights.IMAGENET1K_V1"
57
+ ).features
58
+ self.stage1 = torch.nn.Sequential()
59
+ self.stage2 = torch.nn.Sequential()
60
+ self.stage3 = torch.nn.Sequential()
61
+ self.stage4 = torch.nn.Sequential()
62
+ self.stage5 = torch.nn.Sequential()
63
+ for x in range(0, 4):
64
+ self.stage1.add_module(str(x), vgg_pretrained_features[x])
65
+ self.stage2.add_module(str(4), L2pooling(channels=64))
66
+ for x in range(5, 9):
67
+ self.stage2.add_module(str(x), vgg_pretrained_features[x])
68
+ self.stage3.add_module(str(9), L2pooling(channels=128))
69
+ for x in range(10, 16):
70
+ self.stage3.add_module(str(x), vgg_pretrained_features[x])
71
+ self.stage4.add_module(str(16), L2pooling(channels=256))
72
+ for x in range(17, 23):
73
+ self.stage4.add_module(str(x), vgg_pretrained_features[x])
74
+ self.stage5.add_module(str(23), L2pooling(channels=512))
75
+ for x in range(24, 30):
76
+ self.stage5.add_module(str(x), vgg_pretrained_features[x])
77
+
78
+ self.register_buffer(
79
+ "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
80
+ )
81
+ self.register_buffer(
82
+ "std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
83
+ )
84
+
85
+ self.chns = [3, 64, 128, 256, 512, 512]
86
+ self.register_buffer(
87
+ "alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))
88
+ )
89
+ self.register_buffer("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
90
+ self.alpha.data.normal_(0.1, 0.01)
91
+ self.beta.data.normal_(0.1, 0.01)
92
+ weights = torch.load("test/DISTS_weights.pt")
93
+ self.alpha.data = weights["alpha"]
94
+ self.beta.data = weights["beta"]
95
+
96
+ for param in self.parameters():
97
+ param.requires_grad = False
98
+
99
+ def forward_once(self, x):
100
+ h = (x - self.mean) / self.std
101
+ h = self.stage1(h)
102
+ h_relu1_2 = h
103
+ h = self.stage2(h)
104
+ h_relu2_2 = h
105
+ h = self.stage3(h)
106
+ h_relu3_3 = h
107
+ h = self.stage4(h)
108
+ h_relu4_3 = h
109
+ h = self.stage5(h)
110
+ h_relu5_3 = h
111
+ return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
112
+
113
+ def forward(self, x, y, require_grad=False, batch_average=False):
114
+ if require_grad:
115
+ feats0 = self.forward_once(x)
116
+ feats1 = self.forward_once(y)
117
+ else:
118
+ with torch.no_grad():
119
+ feats0 = self.forward_once(x)
120
+ feats1 = self.forward_once(y)
121
+ dist1 = 0
122
+ dist2 = 0
123
+ c1 = 1e-6
124
+ c2 = 1e-6
125
+ w_sum = self.alpha.sum() + self.beta.sum()
126
+ alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
127
+ beta = torch.split(self.beta / w_sum, self.chns, dim=1)
128
+
129
+ for k in range(len(self.chns)):
130
+ x_mean = feats0[k].mean([2, 3], keepdim=True)
131
+ y_mean = feats1[k].mean([2, 3], keepdim=True)
132
+ S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
133
+ dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
134
+
135
+ x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
136
+ y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
137
+ xy_cov = (feats0[k] * feats1[k]).mean(
138
+ [2, 3], keepdim=True
139
+ ) - x_mean * y_mean
140
+ S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
141
+ dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
142
+
143
+ score = 1 - (dist1 + dist2).squeeze()
144
+
145
+ if batch_average:
146
+ return score.mean()
147
+ else:
148
+ return score
149
+
150
+ class ToBinary(torch.autograd.Function):
151
+ @staticmethod
152
+ def forward(ctx, x):#, noise_scale):
153
+ # if noise_scale > 0:
154
+ # noise_min = 0.5 - noise_scale / 2
155
+ # noise_max = 0.5 + noise_scale / 2
156
+ # return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max))
157
+ # else:
158
+ return torch.floor(x + 0.5) # no need for noise when we have plenty of data
159
+
160
+ @staticmethod
161
+ def backward(ctx, grad_output):
162
+ return grad_output.clone()#, None
163
+
164
+ ########################################################################################################
165
+
166
+ class R_ENCODER(MyModule):
167
+ def __init__(self, args):
168
+ super().__init__()
169
+ self.args = args
170
+ dd = 8
171
+ self.Bxx = nn.BatchNorm2d(dd*64)
172
+
173
+ self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1)
174
+ self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
175
+ self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
176
+
177
+ self.B00 = nn.BatchNorm2d(dd*4)
178
+ self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
179
+ self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
180
+ self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
181
+ self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
182
+
183
+ self.B10 = nn.BatchNorm2d(dd*16)
184
+ self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
185
+ self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
186
+ self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
187
+ self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
188
+
189
+ self.B20 = nn.BatchNorm2d(dd*64)
190
+ self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
191
+ self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
192
+ self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
193
+ self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
194
+ # self.B21 = nn.BatchNorm2d(dd*64)
195
+ # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
196
+ # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
197
+ # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
198
+ # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
199
+
200
+ self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1)
201
+
202
+ @MyFunction
203
+ def forward(self, img):
204
+ ACT = F.mish
205
+
206
+ x = self.CIN(img)
207
+ xx = self.Bxx(F.pixel_unshuffle(x, 8))
208
+ x = x + self.Cx1(ACT(self.Cx0(x)))
209
+
210
+ x = F.pixel_unshuffle(x, 2)
211
+ x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
212
+ x = x + self.C03(ACT(self.C02(x)))
213
+
214
+ x = F.pixel_unshuffle(x, 2)
215
+ x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
216
+ x = x + self.C13(ACT(self.C12(x)))
217
+
218
+ x = F.pixel_unshuffle(x, 2)
219
+ x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
220
+ x = x + self.C23(ACT(self.C22(x)))
221
+ # x = x + self.C25(ACT(self.C24(ACT(self.B21(x)))))
222
+ # x = x + self.C27(ACT(self.C26(x)))
223
+
224
+ x = self.COUT(x + xx)
225
+ return torch.sigmoid(x)
226
+
227
+ ########################################################################################################
228
+
229
+ class R_DECODER(MyModule):
230
+ def __init__(self, args):
231
+ super().__init__()
232
+ self.args = args
233
+ dd = 8
234
+ self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1)
235
+
236
+ self.B00 = nn.BatchNorm2d(dd*64)
237
+ self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
238
+ self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
239
+ self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
240
+ self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
241
+ # self.B01 = nn.BatchNorm2d(dd*64)
242
+ # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
243
+ # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
244
+ # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1)
245
+ # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1)
246
+
247
+ self.B10 = nn.BatchNorm2d(dd*16)
248
+ self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
249
+ self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
250
+ self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1)
251
+ self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1)
252
+
253
+ self.B20 = nn.BatchNorm2d(dd*4)
254
+ self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
255
+ self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
256
+ self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1)
257
+ self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1)
258
+
259
+ self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1)
260
+ self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1)
261
+ self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1)
262
+
263
+ @MyFunction
264
+ def forward(self, code):
265
+ ACT = F.mish
266
+ x = self.CIN(code)
267
+
268
+ x = x + self.C01(ACT(self.C00(ACT(self.B00(x)))))
269
+ x = x + self.C03(ACT(self.C02(x)))
270
+ # x = x + self.C05(ACT(self.C04(ACT(self.B01(x)))))
271
+ # x = x + self.C07(ACT(self.C06(x)))
272
+ x = F.pixel_shuffle(x, 2)
273
+
274
+ x = x + self.C11(ACT(self.C10(ACT(self.B10(x)))))
275
+ x = x + self.C13(ACT(self.C12(x)))
276
+ x = F.pixel_shuffle(x, 2)
277
+
278
+ x = x + self.C21(ACT(self.C20(ACT(self.B20(x)))))
279
+ x = x + self.C23(ACT(self.C22(x)))
280
+ x = F.pixel_shuffle(x, 2)
281
+
282
+ x = x + self.Cx1(ACT(self.Cx0(x)))
283
+ x = self.COUT(x)
284
+
285
+ return torch.sigmoid(x)
286
+
287
+ ########################################################################################################`
288
+
289
+ def cosine_loss(x, y):
290
+ x = F.normalize(x, dim=-1)
291
+ y = F.normalize(y, dim=-1)
292
+ return 1 - torch.einsum('ij,ij->i',[x,y])
293
+
294
+ class RWKV_IMG(pl.LightningModule):
295
+ def __init__(self, args):
296
+ super().__init__()
297
+ self.args = args
298
+
299
+ self.encoder = R_ENCODER(args)
300
+ self.decoder = R_DECODER(args)
301
+
302
+ self.clip_model = None
303
+ clip_name = args.my_img_clip
304
+ if clip_name == 'B32':
305
+ clip_name = 'ViT-B/32'
306
+ elif clip_name == 'B16':
307
+ clip_name = 'ViT-B/16'
308
+ elif clip_name == 'L14':
309
+ clip_name = 'ViT-L/14'
310
+ elif clip_name == 'OB32':
311
+ clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
312
+ self.clip_model = CLIPModel.from_pretrained(clip_name)
313
+ self.clip_model.encode_image = self.clip_model.get_image_features
314
+ if self.clip_model == None:
315
+ self.clip_model, _ = clip.load(clip_name, jit = True)
316
+ self.register_buffer(
317
+ "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
318
+ )
319
+ self.register_buffer(
320
+ "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
321
+ )
322
+
323
+ for n, p in self.named_parameters():
324
+ if 'clip_model' in n:
325
+ p.requires_grad = False
326
+
327
+ self.loss_dists = DISTS()
328
+ # self.loss_ssim = MS_SSIM(data_range=1, size_average=True, channel=3)
329
+
330
+ def configure_optimizers(self):
331
+ args = self.args
332
+ optim_groups = [
333
+ {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
334
+ ]
335
+ if self.deepspeed_offload:
336
+ return DeepSpeedCPUAdam(
337
+ optim_groups,
338
+ lr=self.args.lr_init,
339
+ betas=self.args.betas,
340
+ eps=self.args.adam_eps,
341
+ bias_correction=True,
342
+ adamw_mode=False,
343
+ weight_decay=0,
344
+ amsgrad=False,
345
+ )
346
+ return FusedAdam(
347
+ optim_groups,
348
+ lr=self.args.lr_init,
349
+ betas=self.args.betas,
350
+ eps=self.args.adam_eps,
351
+ bias_correction=True,
352
+ adam_w_mode=False,
353
+ weight_decay=0,
354
+ amsgrad=False,
355
+ )
356
+ # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
357
+
358
+ @property
359
+ def deepspeed_offload(self) -> bool:
360
+ strategy = self.trainer.strategy
361
+ if isinstance(strategy, DeepSpeedStrategy):
362
+ config = strategy.config["zero_optimization"]
363
+ return config.get("offload_optimizer") or config.get("offload_param")
364
+ return False
365
+
366
+ def forward(self, img):
367
+ z = self.encoder(img)
368
+ z = ToBinary.apply(z)#, self.args.my_img_noise_scale)
369
+ out = self.decoder(z)
370
+ return out
371
+
372
+ def training_step(self, batch, batch_idx):
373
+ args = self.args
374
+ img, txt = batch
375
+ out = self(img)
376
+ if self.trainer.is_global_zero:
377
+ if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
378
+ img_dir = f"test/image_model/{args.run_name}"
379
+ if not os.path.exists(img_dir):
380
+ os.makedirs(img_dir)
381
+ vision.utils.save_image(
382
+ img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0
383
+ )
384
+ vision.utils.save_image(
385
+ out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0
386
+ )
387
+
388
+ # loss_ssim = 1 - self.loss_ssim(out, img)
389
+ loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)
390
+
391
+ iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
392
+ ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
393
+ loss_clip = torch.mean(cosine_loss(iii, ooo))
394
+
395
+ if args.my_img_l1_scale > 0:
396
+ loss_l1 = F.l1_loss(out, img)
397
+ return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale
398
+ else:
399
+ return loss_dists + loss_clip * args.my_img_clip_scale
400
+
401
+ def training_step_end(self, batch_parts):
402
+ all = self.all_gather(batch_parts)
403
+ if self.trainer.is_global_zero:
404
+ self.trainer.my_loss_all = all
405
+
406
+ def generate_init_weight(self):
407
+ print(
408
+ f"""
409
+ ############################################################################
410
+ #
411
+ # Init model weight (slow for large models)...
412
+ #
413
+ ############################################################################
414
+ """
415
+ )
416
+ m = {}
417
+ for n in self.state_dict():
418
+ scale = 1
419
+ p = self.state_dict()[n]
420
+ shape = p.shape
421
+ ss = n.split('.')
422
+
423
+ # if ss[0] in ['encoder', 'decoder']:
424
+ # if ss[2] == 'bias':
425
+ # scale = 0
426
+ # # elif n == 'encoder.CIN.weight':
427
+ # # nn.init.dirac_(p)
428
+ # else:
429
+ # try:
430
+ # if ss[1][0] == 'C' and (int(ss[1][2]) % 2 == 1):
431
+ # scale = 0
432
+ # except:
433
+ # pass
434
+ # m[n] = p * scale
435
+
436
+ m[n] = p
437
+
438
+ m[n] = m[n].cpu()
439
+ if os.environ["RWKV_FLOAT_MODE"] == "fp16":
440
+ m[n] = m[n].half()
441
+ elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
442
+ m[n] = m[n].bfloat16()
443
+
444
+ gc.collect()
445
+ torch.cuda.empty_cache()
446
+ return m
src/model_run.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ import types
6
+ import torch
7
+ import math, os, gc
8
+ from torch.nn import functional as F
9
+ import torch.nn as nn
10
+ from typing import List, Dict
11
+
12
+ MyModule = nn.Module
13
+ def __nop(ob):
14
+ return ob
15
+ MyFunction = __nop
16
+
17
+ # # try torchdynamo
18
+ # import torchdynamo
19
+ # MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output
20
+
21
+ # try torch jit --> faster for fp32, slower for fp16 (why?)
22
+ if os.environ["RWKV_JIT_ON"] == "1":
23
+ MyModule = torch.jit.ScriptModule
24
+ MyFunction = torch.jit.script_method
25
+
26
+ RWKV_HEAD_QK_DIM = 0
27
+
28
+ DEBUG_TIME = False # True False - show trained time-coeffs
29
+
30
+ RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer
31
+
32
+ ############################################################################################################
33
+
34
+ class RWKV_RNN(MyModule):
35
+ def __init__(self, args):
36
+ super().__init__()
37
+
38
+ self.args = args
39
+ self.FLOAT_MODE = args.FLOAT_MODE
40
+ self.RUN_DEVICE = args.RUN_DEVICE
41
+
42
+ with torch.no_grad():
43
+ w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
44
+ # refine weights and send to correct device
45
+ keys = list(w.keys())
46
+ if 'pos_emb_x' in keys:
47
+ w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:]
48
+ keys = list(w.keys())
49
+ print_need_newline = False
50
+ for x in keys:
51
+ block_id = 0
52
+ if 'blocks.' in x:
53
+ block_id = int(x.split('.')[1])
54
+ if 'att.output.weight' in x:
55
+ w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
56
+ if 'ffn.value.weight' in x:
57
+ w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER))
58
+
59
+ if '.time_' in x:
60
+ w[x] = w[x].squeeze()
61
+ if DEBUG_TIME:
62
+ print(x, w[x].numpy())
63
+ if '.time_decay' in x:
64
+ w[x] = w[x].float()
65
+ w[x] = -torch.exp(w[x])
66
+ elif '.time_first' in x:
67
+ w[x] = w[x].float()
68
+ else:
69
+ if self.FLOAT_MODE == "fp32":
70
+ w[x] = w[x].float()
71
+ elif self.FLOAT_MODE == "bf16":
72
+ w[x] = w[x].bfloat16()
73
+ elif self.FLOAT_MODE == "fp16":
74
+ w[x] = w[x].half()
75
+
76
+ w[x].requires_grad = False
77
+ if args.RUN_DEVICE == 'cuda' and x != 'emb.weight':
78
+ w[x] = w[x].cuda()
79
+
80
+ if ('blocks.' not in x) or ('blocks.0.' in x):
81
+ if print_need_newline:
82
+ print_need_newline = False
83
+ else:
84
+ print_need_newline = True
85
+
86
+ # store weights in self.w
87
+ keys = list(w.keys())
88
+ self.w = types.SimpleNamespace()
89
+ for x in keys:
90
+ xx = x.split('.')
91
+ here = self.w
92
+ for i in range(len(xx)):
93
+ if xx[i].isdigit():
94
+ ii = int(xx[i])
95
+ if ii not in here:
96
+ here[ii] = types.SimpleNamespace()
97
+ here = here[ii]
98
+ else:
99
+ if i == len(xx) - 1:
100
+ setattr(here, xx[i], w[x])
101
+ elif not hasattr(here, xx[i]):
102
+ if xx[i+1].isdigit():
103
+ setattr(here, xx[i], {})
104
+ else:
105
+ setattr(here, xx[i], types.SimpleNamespace())
106
+ here = getattr(here, xx[i])
107
+
108
+ self.eval()
109
+ gc.collect()
110
+ torch.cuda.empty_cache()
111
+
112
+ def LN(self, x, w):
113
+ return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
114
+
115
+ # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp
116
+
117
+ @MyFunction
118
+ def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
119
+ if self.FLOAT_MODE == "bf16":
120
+ xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k)
121
+ xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r)
122
+ state[5*i+0] = x.float()
123
+ elif self.FLOAT_MODE == "fp16":
124
+ xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k)
125
+ xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r)
126
+ state[5*i+0] = x.float()
127
+ else:
128
+ xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
129
+ xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
130
+ state[5*i+0] = x
131
+
132
+ r = torch.sigmoid(rw @ xr)
133
+ k = torch.square(torch.relu(kw @ xk))
134
+ kv = vw @ k
135
+
136
+ return r * kv
137
+
138
+ @MyFunction
139
+ def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
140
+ if self.FLOAT_MODE == "bf16":
141
+ xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k)
142
+ xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v)
143
+ xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r)
144
+ state[5*i+1] = x.float()
145
+ elif self.FLOAT_MODE == "fp16":
146
+ xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k)
147
+ xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v)
148
+ xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r)
149
+ state[5*i+1] = x.float()
150
+ else:
151
+ xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
152
+ xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
153
+ xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
154
+ state[5*i+1] = x
155
+
156
+ r = torch.sigmoid(rw @ xr)
157
+ k = kw @ xk
158
+ v = vw @ xv
159
+
160
+ if '16' in self.FLOAT_MODE:
161
+ kk = k.float()
162
+ vv = v.float()
163
+ else:
164
+ kk = k
165
+ vv = v
166
+ aa = state[5*i+2]
167
+ bb = state[5*i+3]
168
+ pp = state[5*i+4]
169
+ ww = time_first + kk
170
+ p = torch.maximum(pp, ww)
171
+ e1 = torch.exp(pp - p)
172
+ e2 = torch.exp(ww - p)
173
+ a = e1 * aa + e2 * vv
174
+ b = e1 * bb + e2
175
+ ww = pp + time_decay
176
+ p = torch.maximum(ww, kk)
177
+ e1 = torch.exp(ww - p)
178
+ e2 = torch.exp(kk - p)
179
+ state[5*i+2] = e1 * aa + e2 * vv
180
+ state[5*i+3] = e1 * bb + e2
181
+ state[5*i+4] = p
182
+ if self.FLOAT_MODE == "bf16":
183
+ wkv = (a / b).type(torch.bfloat16)
184
+ elif self.FLOAT_MODE == "fp16":
185
+ wkv = (a / b).half()
186
+ else:
187
+ wkv = a / b
188
+
189
+ return ow @ (r * wkv)
190
+
191
+ def forward(self, ctx, state, preprocess_only = False):
192
+ with torch.no_grad():
193
+ w = self.w
194
+ args = self.args
195
+
196
+ x = w.emb.weight[ctx[-1]]
197
+ if self.RUN_DEVICE == 'cuda':
198
+ x = x.cuda()
199
+ try:
200
+ pos_emb = w.pos_emb[len(ctx)-1]
201
+ x = x + pos_emb
202
+ except:
203
+ pass
204
+
205
+ if state == None:
206
+ state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
207
+ for i in range(args.n_layer):
208
+ state[5*i+4] -= 1e30
209
+
210
+ for i in range(args.n_layer):
211
+ if i == 0:
212
+ x = self.LN(x, w.blocks[i].ln0)
213
+
214
+ ww = w.blocks[i].att
215
+ x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i,
216
+ ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
217
+ ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
218
+
219
+ ww = w.blocks[i].ffn
220
+ x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i,
221
+ ww.time_mix_k, ww.time_mix_r,
222
+ ww.key.weight, ww.value.weight, ww.receptance.weight)
223
+
224
+ if (i+1) % RWKV_RESCALE_LAYER == 0:
225
+ x = x / 2
226
+
227
+ if preprocess_only:
228
+ return state
229
+
230
+ x = self.LN(x, w.ln_out)
231
+ x = w.head.weight @ x
232
+
233
+ return x.float(), state
src/trainer.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, math, time, datetime, subprocess
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
6
+
7
+ def my_save(dd, ff):
8
+ if '14b-run1' not in ff:
9
+ torch.save(dd, ff)
10
+ else:
11
+ fn = ff.split('/')[-1]
12
+ fff = '/dev/shm/' + fn
13
+ torch.save(dd, fff)
14
+ subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
15
+
16
+ class train_callback(pl.Callback):
17
+ def __init__(self, args):
18
+ super().__init__()
19
+ self.args = args
20
+
21
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
22
+ args = self.args
23
+ # if args.cuda_cleanup > 0:
24
+ # torch.cuda.empty_cache()
25
+ real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
26
+
27
+ # LR schedule
28
+ w_step = args.warmup_steps
29
+ if args.lr_final == args.lr_init or args.epoch_count == 0:
30
+ lr = args.lr_init
31
+ if trainer.global_step < w_step:
32
+ lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
33
+ else:
34
+ decay_step = real_step - args.my_pile_edecay * args.epoch_steps
35
+ decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
36
+ progress = (decay_step - w_step + 1) / (decay_total - w_step)
37
+ progress = min(1, max(0, progress))
38
+
39
+ if args.lr_final == 0 or args.lr_init == 0: # linear decay
40
+ lr = args.lr_init + (args.lr_final - args.lr_init) * progress
41
+ else: # exp decay
42
+ lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
43
+
44
+ if trainer.global_step < w_step:
45
+ lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
46
+ # if trainer.is_global_zero:
47
+ # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
48
+
49
+ for param_group in trainer.optimizers[0].param_groups:
50
+ if args.layerwise_lr > 0:
51
+ param_group["lr"] = lr * param_group["my_lr_scale"]
52
+ # print(param_group["lr"], param_group["my_lr_scale"])
53
+ else:
54
+ param_group["lr"] = lr
55
+
56
+ trainer.my_lr = lr
57
+ # rank_zero_info(f"{real_step} {lr}")
58
+
59
+ if trainer.global_step == 0:
60
+ if trainer.is_global_zero: # logging
61
+ trainer.my_loss_sum = 0
62
+ trainer.my_loss_count = 0
63
+ trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
64
+ trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
65
+ try:
66
+ print(f"\n{trainer.strategy.config}\n")
67
+ trainer.my_log.write(f"{trainer.strategy.config}\n")
68
+ except:
69
+ pass
70
+ trainer.my_log.flush()
71
+ if len(args.wandb) > 0:
72
+ print("Login to wandb...")
73
+ import wandb
74
+ wandb.init(
75
+ project=args.wandb,
76
+ name=args.run_name + " " + args.my_timestamp,
77
+ config=args,
78
+ save_code=False,
79
+ )
80
+ trainer.my_wandb = wandb
81
+
82
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
83
+ args = self.args
84
+ if trainer.is_global_zero: # logging
85
+ t_now = time.time_ns()
86
+ token_per_step = args.ctx_len * args.real_bsz
87
+ real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
88
+ kt_s = 0
89
+ try:
90
+ t_cost = (t_now - trainer.my_time_ns) / 1e9
91
+ kt_s = token_per_step / t_cost / 1000
92
+ self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
93
+ self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
94
+ except:
95
+ pass
96
+ trainer.my_time_ns = t_now
97
+ trainer.my_loss = trainer.my_loss_all.float().mean().item()
98
+ trainer.my_loss_sum += trainer.my_loss
99
+ trainer.my_loss_count += 1
100
+ trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
101
+ self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
102
+ self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
103
+ # self.log("s", real_step, prog_bar=True, on_step=True)
104
+
105
+ if len(args.wandb) > 0:
106
+ lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
107
+ if kt_s > 0:
108
+ lll["kt/s"] = kt_s
109
+ trainer.my_wandb.log(lll, step=int(real_step))
110
+ if args.magic_prime > 0:
111
+ expand_factor = 2 if args.my_qa_mask > 0 else 1
112
+ if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1 + int(args.my_random_steps):
113
+ to_save_dict = pl_module.state_dict()
114
+ my_save(
115
+ to_save_dict,
116
+ f"{args.proj_dir}/rwkv-final.pth",
117
+ )
118
+
119
+
120
+ def on_train_epoch_start(self, trainer, pl_module):
121
+ args = self.args
122
+ dataset = trainer.train_dataloader.dataset.datasets
123
+ assert "MyDataset" in str(dataset)
124
+ dataset.global_rank = trainer.global_rank
125
+ dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
126
+ dataset.world_size = trainer.world_size
127
+ # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
128
+
129
+ def on_train_epoch_end(self, trainer, pl_module):
130
+ args = self.args
131
+ if trainer.is_global_zero: # logging & save state_dict
132
+ if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
133
+ if args.data_type == 'wds_img':
134
+ raw_dict = pl_module.state_dict()
135
+ to_save_dict = {}
136
+ for k in raw_dict:
137
+ if k.startswith('encoder.') or k.startswith('decoder.'):
138
+ to_save_dict[k] = raw_dict[k]
139
+ else:
140
+ to_save_dict = pl_module.state_dict()
141
+ try:
142
+ my_save(
143
+ to_save_dict,
144
+ f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
145
+ )
146
+ except Exception as e:
147
+ print('Error\n\n', e, '\n\n')
148
+ trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
149
+ trainer.my_log.flush()
150
+
151
+ trainer.my_loss_sum = 0
152
+ trainer.my_loss_count = 0
153
+
154
+
155
+ @rank_zero_only
156
+ def generate_init_weight(model, init_weight_name):
157
+ mm = model.generate_init_weight()
158
+
159
+ if model.args.my_pile_stage == 1:
160
+ if len(model.args.load_model) > 0:
161
+ print(f"Combine weights from {model.args.load_model}...")
162
+ load_dict = torch.load(model.args.load_model, map_location="cpu")
163
+ for k in load_dict:
164
+ assert k in mm
165
+ src = load_dict[k]
166
+ try:
167
+ mm[k] = src.reshape(mm[k].shape)
168
+ except:
169
+ tmp = mm[k].squeeze().clone()
170
+ print(k, src.shape, '-->', mm[k].shape)
171
+ ss = src.shape[0]
172
+ dd = tmp.shape[0]
173
+ for i in range(dd):
174
+ pos = i / dd * ss
175
+ if pos >= ss - 1:
176
+ tmp[i] = src[ss-1]
177
+ else:
178
+ p0 = int(math.floor(pos))
179
+ ii = pos - p0
180
+ tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
181
+ mm[k] = tmp.reshape(mm[k].shape)
182
+ sss = src.squeeze().float().cpu().numpy()
183
+ print(sss[:10], '...', sss[-10:])
184
+ mmm = mm[k].squeeze().float().cpu().numpy()
185
+ print(mmm[:10], '...', mmm[-10:])
186
+
187
+ print(f"Save to {init_weight_name}...")
188
+ torch.save(mm, init_weight_name)
189
+
190
+ if model.args.my_pile_stage == 1:
191
+ print("Done. Now go for stage 2.")
192
+ exit(0)
src/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, time, random, os
2
+ import numpy as np
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ time_slot = {}
7
+ time_ref = time.time_ns()
8
+
9
+ def record_time(name):
10
+ if name not in time_slot:
11
+ time_slot[name] = 1e20
12
+ tt = (time.time_ns() - time_ref) / 1e9
13
+ if tt < time_slot[name]:
14
+ time_slot[name] = tt
15
+
16
+ class TOKENIZER():
17
+ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
18
+ if 'list' in str(type(WORD_NAME)):
19
+ self.charMode = False
20
+ if WORD_NAME[0] == WORD_NAME[1]:
21
+ from transformers import PreTrainedTokenizerFast
22
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
23
+ else:
24
+ from transformers import GPT2TokenizerFast
25
+ self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
26
+ self.vocab_size = len(self.tokenizer)
27
+ else:
28
+ self.charMode = True
29
+ with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
30
+ self.word_table = json.load(result_file)
31
+
32
+ self.vocab_size = len(self.word_table)
33
+
34
+ self.stoi = {v: int(k) for k, v in self.word_table.items()}
35
+ self.itos = {int(k): v for k, v in self.word_table.items()}
36
+
37
+ self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
38
+
39
+ def refine_context(self, context):
40
+ context = context.strip().split('\n')
41
+ for c in range(len(context)):
42
+ context[c] = context[c].strip().strip('\u3000').strip('\r')
43
+ context = list(filter(lambda c: c != '', context))
44
+ context = '\n' + ('\n'.join(context)).strip()
45
+ if context == '':
46
+ context = '\n'
47
+ return context
48
+
49
+ def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
50
+ # out[self.UNKNOWN_CHAR] = -float('Inf')
51
+ lastChar = int(x[-1])
52
+
53
+ probs = F.softmax(out, dim=-1)
54
+
55
+ if self.charMode:
56
+ if self.itos[lastChar] == '\n':
57
+ top_p = top_p_newline
58
+ else:
59
+ top_p = top_p_usual
60
+ else:
61
+ top_p = top_p_usual
62
+
63
+ if os.environ["RWKV_RUN_DEVICE"] == "cpu":
64
+ probs = probs.numpy()
65
+ sorted_probs = np.sort(probs)[::-1]
66
+ cumulative_probs = np.cumsum(sorted_probs)
67
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
68
+ probs[probs < cutoff] = 0
69
+ if temperature != 1.0:
70
+ probs = probs.pow(1.0 / temperature)
71
+ probs = probs / np.sum(probs)
72
+ out = np.random.choice(a=len(probs), p=probs)
73
+ return out
74
+ else:
75
+ sorted_probs = torch.sort(probs, descending=True)[0]
76
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
77
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
78
+ probs[probs < cutoff] = 0
79
+ if temperature != 1.0:
80
+ probs = probs.pow(1.0 / temperature)
81
+ out = torch.multinomial(probs, num_samples=1)[0]
82
+ return out
83
+
84
+ def MaybeIsPrime(number):
85
+ if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
86
+ return True
87
+ else:
88
+ return False
89
+
90
+
91
+ def FermatPrimalityTest(number):
92
+ if number > 1:
93
+ for time in range(3):
94
+ randomNumber = random.randint(2, number) - 1
95
+ if pow(randomNumber, number - 1, number) != 1:
96
+ return False
97
+ return True
98
+ else:
99
+ return False
100
+
101
+
102
+ def MillerRabinPrimalityTest(number):
103
+ if number == 2:
104
+ return True
105
+ elif number == 1 or number % 2 == 0:
106
+ return False
107
+ oddPartOfNumber = number - 1
108
+ timesTwoDividNumber = 0
109
+ while oddPartOfNumber % 2 == 0:
110
+ oddPartOfNumber = oddPartOfNumber // 2
111
+ timesTwoDividNumber = timesTwoDividNumber + 1
112
+
113
+ for time in range(3):
114
+ while True:
115
+ randomNumber = random.randint(2, number) - 1
116
+ if randomNumber != 0 and randomNumber != 1:
117
+ break
118
+
119
+ randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
120
+
121
+ if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
122
+ iterationNumber = 1
123
+
124
+ while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
125
+ randomNumberWithPower = pow(randomNumberWithPower, 2, number)
126
+ iterationNumber = iterationNumber + 1
127
+ if randomNumberWithPower != (number - 1):
128
+ return False
129
+
130
+ return True
train.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ if __name__ == "__main__":
6
+ from argparse import ArgumentParser
7
+ from pytorch_lightning import Trainer
8
+ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
9
+
10
+ rank_zero_info("########## work in progress ##########")
11
+
12
+ ########################################################################################################
13
+ #
14
+ # example: train a simple L12-D768 RWKV on dummy data
15
+ #
16
+ # python train.py --load_model "" --wandb "" --proj_dir "out" \
17
+ # --data_file "" --data_type "dummy" --vocab_size 0 \
18
+ # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
19
+ # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
20
+ # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
21
+ # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
22
+
23
+ # example: train a simple L6-D512 RWKV from scratch on enwik8
24
+ #
25
+ # python train.py --load_model "" --wandb "" --proj_dir "out" \
26
+ # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
27
+ # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
28
+ # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
29
+ # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
30
+ # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
31
+
32
+ # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
33
+ #
34
+ # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
35
+ # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
36
+ # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
37
+ # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
38
+ # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
39
+ # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
40
+
41
+ # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
42
+ #
43
+ # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
44
+ # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
45
+ # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
46
+ # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
47
+ # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
48
+ # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
49
+
50
+ parser = ArgumentParser()
51
+
52
+ parser.add_argument("--load_model", default="", type=str) # full path, with .pth
53
+ parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
54
+ parser.add_argument("--proj_dir", default="out", type=str)
55
+ parser.add_argument("--random_seed", default="-1", type=int)
56
+
57
+ parser.add_argument("--data_file", default="", type=str)
58
+ parser.add_argument("--data_type", default="utf-8", type=str)
59
+ parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
60
+
61
+ parser.add_argument("--ctx_len", default=1024, type=int)
62
+ parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
63
+ parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
64
+ parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
65
+ parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
66
+
67
+ parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
68
+ parser.add_argument("--n_layer", default=6, type=int)
69
+ parser.add_argument("--n_embd", default=512, type=int)
70
+ parser.add_argument("--dim_att", default=0, type=int)
71
+ parser.add_argument("--dim_ffn", default=0, type=int)
72
+ parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
73
+ parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
74
+ parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
75
+ parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
76
+
77
+ parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
78
+ parser.add_argument("--lr_final", default=1e-5, type=float)
79
+ parser.add_argument("--warmup_steps", default=-1, type=int) # try 50 if you load a model
80
+ parser.add_argument("--beta1", default=0.9, type=float)
81
+ parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
82
+ parser.add_argument("--adam_eps", default=1e-8, type=float)
83
+ parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
84
+
85
+ parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
86
+ parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
87
+ parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
88
+ parser.add_argument("--my_pile_edecay", default=0, type=int)
89
+ parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
90
+ parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
91
+ # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
92
+
93
+ parser.add_argument("--my_img_version", default=0, type=str)
94
+ parser.add_argument("--my_img_size", default=0, type=int)
95
+ parser.add_argument("--my_img_bit", default=0, type=int)
96
+ parser.add_argument("--my_img_clip", default='x', type=str)
97
+ parser.add_argument("--my_img_clip_scale", default=1, type=float)
98
+ parser.add_argument("--my_img_l1_scale", default=0, type=float)
99
+ parser.add_argument("--my_img_encoder", default='x', type=str)
100
+ # parser.add_argument("--my_img_noise_scale", default=0, type=float)
101
+ parser.add_argument("--my_sample_len", default=0, type=int)
102
+ parser.add_argument("--my_ffn_shift", default=1, type=int)
103
+ parser.add_argument("--my_att_shift", default=1, type=int)
104
+ parser.add_argument("--my_pos_emb", default=0, type=int)
105
+ parser.add_argument("--load_partial", default=0, type=int)
106
+ parser.add_argument("--magic_prime", default=0, type=int)
107
+ parser.add_argument("--my_qa_mask", default=0, type=int)
108
+ parser.add_argument("--my_random_steps", default=0, type=int)
109
+ parser.add_argument("--my_testing", default='', type=str)
110
+
111
+ parser = Trainer.add_argparse_args(parser)
112
+ args = parser.parse_args()
113
+
114
+ ########################################################################################################
115
+
116
+ import os, warnings, math, datetime, sys, time, importlib
117
+ import numpy as np
118
+ import torch
119
+ from torch.utils.data import DataLoader
120
+ if "deepspeed" in args.strategy:
121
+ import deepspeed
122
+ import pytorch_lightning as pl
123
+ from pytorch_lightning import seed_everything
124
+
125
+ if args.random_seed >= 0:
126
+ print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
127
+ seed_everything(args.random_seed)
128
+
129
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
130
+ warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
131
+ warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
132
+ # os.environ["WDS_SHOW_SEED"] = "1"
133
+
134
+ args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
135
+ args.enable_checkpointing = False
136
+ args.replace_sampler_ddp = False
137
+ args.logger = False
138
+ args.gradient_clip_val = 1.0
139
+ args.num_sanity_val_steps = 0
140
+ args.check_val_every_n_epoch = int(1e20)
141
+ args.log_every_n_steps = int(1e20)
142
+ args.max_epochs = -1 # continue forever
143
+ args.betas = (args.beta1, args.beta2)
144
+ args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
145
+ os.environ["RWKV_T_MAX"] = str(args.ctx_len)
146
+ os.environ["RWKV_MY_TESTING"] = args.my_testing
147
+ if args.dim_att <= 0:
148
+ args.dim_att = args.n_embd
149
+ if args.dim_ffn <= 0:
150
+ args.dim_ffn = args.n_embd * 4
151
+
152
+ if args.data_type == "wds_img":
153
+ args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
154
+ args.proj_dir = f"{args.proj_dir}-{args.run_name}"
155
+ else:
156
+ args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
157
+ if not os.path.exists(args.proj_dir):
158
+ os.makedirs(args.proj_dir)
159
+
160
+ if args.my_pile_stage > 0:
161
+ magic_prime_bak = args.magic_prime
162
+
163
+ if args.my_pile_version == 1:
164
+ if args.ctx_len == 1024:
165
+ args.magic_prime = 324331313
166
+ args.epoch_count = 8043
167
+ elif args.ctx_len == 2048:
168
+ args.magic_prime = 162165671
169
+ args.epoch_count = 4021
170
+ elif args.ctx_len == 4096:
171
+ args.magic_prime = 81082817
172
+ args.epoch_count = 2010
173
+ elif args.ctx_len == 8192:
174
+ args.magic_prime = 40541399
175
+ args.epoch_count = 1005
176
+ else:
177
+ if args.ctx_len == 1024:
178
+ args.magic_prime = 1670239709
179
+ args.epoch_count = 41423
180
+ elif args.ctx_len == 2048:
181
+ args.magic_prime = 835119767
182
+ args.epoch_count = 20711
183
+ elif args.ctx_len == 4096:
184
+ args.magic_prime = 417559889
185
+ args.epoch_count = 10355
186
+ elif args.ctx_len == 6144:
187
+ args.magic_prime = 278373239
188
+ args.epoch_count = 6903
189
+ elif args.ctx_len == 8192:
190
+ args.magic_prime = 208779911
191
+ args.epoch_count = 5177
192
+ if args.my_pile_shift < 0:
193
+ args.my_pile_shift = 0
194
+
195
+ if magic_prime_bak > 0:
196
+ args.magic_prime = magic_prime_bak
197
+
198
+ args.epoch_steps = 40320 // args.real_bsz
199
+ assert args.epoch_steps * args.real_bsz == 40320
200
+ if args.my_pile_stage == 2:
201
+ assert args.lr_final == args.lr_init
202
+ if args.my_pile_stage >= 2: # find latest saved model
203
+ list_p = []
204
+ for p in os.listdir(args.proj_dir):
205
+ if p.startswith("rwkv") and p.endswith(".pth"):
206
+ p = ((p.split("-"))[1].split("."))[0]
207
+ if p == "init":
208
+ p = -1
209
+ else:
210
+ p = int(p)
211
+ list_p += [p]
212
+ list_p.sort()
213
+ max_p = list_p[-1]
214
+ if len(list_p) > 1:
215
+ args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
216
+ if max_p == -1:
217
+ args.load_model = f"{args.proj_dir}/rwkv-init.pth"
218
+ else:
219
+ args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
220
+ if args.warmup_steps < 0:
221
+ if args.my_pile_stage == 2:
222
+ args.warmup_steps = 10
223
+ else:
224
+ args.warmup_steps = 30
225
+ args.epoch_begin = max_p + 1
226
+
227
+ samples_per_epoch = args.epoch_steps * args.real_bsz
228
+ tokens_per_epoch = samples_per_epoch * args.ctx_len
229
+ rank_zero_info(
230
+ f"""
231
+ ############################################################################
232
+ #
233
+ # RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
234
+ #
235
+ # Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
236
+ #
237
+ # Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
238
+ #
239
+ # Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
240
+ #
241
+ # Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
242
+ #
243
+ # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
244
+ #
245
+ # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
246
+ # Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
247
+ # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
248
+ #
249
+ ############################################################################
250
+ """
251
+ )
252
+ rank_zero_info(str(vars(args)) + "\n")
253
+
254
+ assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
255
+
256
+ if args.lr_final == 0 or args.lr_init == 0:
257
+ rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
258
+
259
+ assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
260
+ os.environ["RWKV_FLOAT_MODE"] = args.precision
261
+ if args.precision == "fp32":
262
+ for i in range(10):
263
+ rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
264
+ if args.precision == "fp16":
265
+ rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
266
+
267
+ os.environ["RWKV_JIT_ON"] = "1"
268
+ if "deepspeed_stage_3" in args.strategy:
269
+ os.environ["RWKV_JIT_ON"] = "0"
270
+
271
+ torch.backends.cudnn.benchmark = True
272
+ torch.backends.cudnn.enabled = True
273
+ if args.precision == "fp32":
274
+ torch.backends.cudnn.allow_tf32 = False
275
+ torch.backends.cuda.matmul.allow_tf32 = False
276
+ else:
277
+ torch.backends.cudnn.allow_tf32 = True
278
+ torch.backends.cuda.matmul.allow_tf32 = True
279
+
280
+ if "32" in args.precision:
281
+ args.precision = 32
282
+ elif args.precision == "fp16":
283
+ args.precision = 16
284
+ else:
285
+ args.precision = "bf16"
286
+
287
+ ########################################################################################################
288
+
289
+ from src.trainer import train_callback, generate_init_weight
290
+ from src.dataset import MyDataset
291
+
292
+ train_data = MyDataset(args)
293
+ args.vocab_size = train_data.vocab_size
294
+
295
+ if args.data_type == 'wds_img':
296
+ from src.model_img import RWKV_IMG
297
+ model = RWKV_IMG(args)
298
+ else:
299
+ from src.model import RWKV
300
+ model = RWKV(args)
301
+
302
+ if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
303
+ init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
304
+ generate_init_weight(model, init_weight_name) # save initial weights
305
+ args.load_model = init_weight_name
306
+
307
+ rank_zero_info(f"########## Loading {args.load_model}... ##########")
308
+ try:
309
+ load_dict = torch.load(args.load_model, map_location="cpu")
310
+ except:
311
+ rank_zero_info(f"Bad checkpoint {args.load_model}")
312
+ if args.my_pile_stage >= 2: # try again using another checkpoint
313
+ max_p = args.my_pile_prev_p
314
+ if max_p == -1:
315
+ args.load_model = f"{args.proj_dir}/rwkv-init.pth"
316
+ else:
317
+ args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
318
+ args.epoch_begin = max_p + 1
319
+ rank_zero_info(f"Trying {args.load_model}")
320
+ load_dict = torch.load(args.load_model, map_location="cpu")
321
+
322
+ if args.load_partial == 1:
323
+ load_keys = load_dict.keys()
324
+ for k in model.state_dict():
325
+ if k not in load_keys:
326
+ load_dict[k] = model.state_dict()[k]
327
+ model.load_state_dict(load_dict)
328
+
329
+ trainer = Trainer.from_argparse_args(
330
+ args,
331
+ callbacks=[train_callback(args)],
332
+ )
333
+
334
+ if trainer.global_rank == 0:
335
+ for n in model.state_dict():
336
+ shape = model.state_dict()[n].shape
337
+ shape = [i for i in shape if i != 1]
338
+ if len(shape) > 1:
339
+ print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
340
+ else:
341
+ print(f"{str(shape[0]).ljust(5)} {n}")
342
+
343
+ if "deepspeed" in args.strategy:
344
+ trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
345
+ trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
346
+
347
+ # must set shuffle=False, persistent_workers=False (because worker is in another thread)
348
+ data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
349
+
350
+ trainer.fit(model, data_loader)
verify.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################################################################################################
2
+ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3
+ ########################################################################################################
4
+
5
+ # this is for verifying the results of different models and make sure they agree with each other
6
+
7
+ import os, sys, types
8
+ import numpy as np
9
+ import torch
10
+ np.set_printoptions(precision=4, suppress=True, linewidth=200)
11
+ try:
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
13
+ except:
14
+ pass
15
+ torch.backends.cudnn.benchmark = True
16
+ torch.backends.cudnn.allow_tf32 = False
17
+ torch.backends.cuda.matmul.allow_tf32 = False
18
+
19
+ os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
20
+ os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
21
+ RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
22
+
23
+ TOKEN_MODE = 'pile'
24
+
25
+ if TOKEN_MODE == 'pile':
26
+ WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
27
+ MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
28
+ n_layer = 32
29
+ n_embd = 2560
30
+ ctx_len = 1024
31
+ UNKNOWN_CHAR = None
32
+
33
+ from src.utils import TOKENIZER
34
+ tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
35
+ if TOKEN_MODE == 'pile':
36
+ tokenizer.vocab_size = 50277
37
+
38
+ ########################################################################################################
39
+
40
+ os.environ["RWKV_JIT_ON"] = "1"
41
+ os.environ["RWKV_T_MAX"] = str(ctx_len)
42
+
43
+ from src.model_run import RWKV_RNN
44
+ from src.model import RWKV
45
+
46
+ args = types.SimpleNamespace()
47
+ args.vocab_size = tokenizer.vocab_size
48
+ args.ctx_len = ctx_len
49
+ args.n_embd = n_embd
50
+ args.n_layer = n_layer
51
+ args.head_qk = 0
52
+ args.pre_ffn = 0
53
+ args.grad_cp = 0
54
+ args.my_pos_emb = 0
55
+ model_train = RWKV(args).to(RUN_DEVICE)
56
+
57
+ if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
58
+ model_train = model_train.half()
59
+ elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
60
+ model_train = model_train.bfloat16()
61
+
62
+ print('loading ' + MODEL_NAME)
63
+ m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
64
+ model_train.load_state_dict(m2)
65
+
66
+ if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
67
+ model_train = model_train.half()
68
+ elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
69
+ model_train = model_train.bfloat16()
70
+
71
+ args.MODEL_NAME = MODEL_NAME
72
+ args.RUN_DEVICE = RUN_DEVICE
73
+ args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
74
+ model_rnn = RWKV_RNN(args)
75
+
76
+ ########################################################################################################
77
+
78
+ print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
79
+
80
+ # context = '\nIn a'
81
+ context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
82
+
83
+ if TOKEN_MODE == 'pile':
84
+ ctx = tokenizer.tokenizer.encode(context)
85
+ print(f'input len {len(ctx)} data {ctx}')
86
+
87
+ ########################################################################################################
88
+
89
+ with torch.no_grad():
90
+ print('\nRWKV-train output')
91
+ out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
92
+ print(out, '\n')
93
+
94
+ print('\nRWKV-RNN output')
95
+ state = None
96
+ out = None
97
+ src_len = len(ctx)
98
+ for i in range(src_len):
99
+ x = ctx[:i+1]
100
+ out, state = model_rnn.forward(x, state)
101
+ if i < 3 or i >= src_len - 3:
102
+ print(out.detach().cpu().numpy())
103
+ if i == 2:
104
+ print('...')
zrwkv-37fifth.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:426991ea8333bdc4a16fa27551b1e8e7ebe9090e2a5ff346d95290f4ffc55a3e
3
+ size 338718755