wanicca committed on
Commit 0c6166f
1 Parent(s): ee5c3c3

Add LoRA merging and loading

Files changed (3)
  1. .gitignore +2 -0
  2. app.py +17 -6
  3. rwkv_lora.py +325 -0
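
For orientation, a minimal sketch of how the flags added in this commit might be combined when launching the app locally; the LoRA file name, alpha value, and layer range below are placeholders, not assets shipped with this commit:

python app.py --lora path/to/your_lora.pth --lora_alpha 16 --lora_layer_filter "25-31"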
.gitignore ADDED
@@ -0,0 +1,2 @@
+#python
+__pycache__/
app.py CHANGED
@@ -3,6 +3,7 @@ import argparse
 import os, gc, torch
 from datetime import datetime
 from huggingface_hub import hf_hub_download
+import torch
 # from pynvml import *
 # nvmlInit()
 # gpu_h = nvmlDeviceGetHandleByIndex(0)
@@ -14,20 +15,30 @@ parser = argparse.ArgumentParser(prog = 'ChatGal RWKV')
 parser.add_argument('--share',action='store_true')
 parser.add_argument('--ckpt',type=str,default="rwkv-loramerge_0.5-0426-v2-4096-epoch11.pth")
 parser.add_argument('--model_path',type=str,default=None,help="local model path")
+parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
+parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
+parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "25-31"')
 args = parser.parse_args()
 os.environ["RWKV_JIT_ON"] = '1'

-from rwkv.model import RWKV
+# from rwkv.model import RWKV
+from rwkv_lora import RWKV
+lora_kwargs = {
+    "lora":args.lora,
+    "lora_alpha":args.lora_alpha,
+    "lora_layer_filter":args.lora_layer_filter
+}
 if args.model_path:
     model_path = args.model_path
 else:
     model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename=args.ckpt)
-if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1':
+# if 'ON_COLAB' in os.environ and os.environ['ON_COLAB'] == '1':
+if torch.cuda.is_available() and torch.cuda.device_count()>0:
     os.environ["RWKV_JIT_ON"] = '0'
     os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
-    model = RWKV(model=model_path, strategy='cuda bf16')
+    model = RWKV(model=model_path, strategy='cuda bf16',**lora_kwargs)
 else:
-    model = RWKV(model=model_path, strategy='cpu bf16')
+    model = RWKV(model=model_path, strategy='cpu bf16',**lora_kwargs)
 from utils import PIPELINE, PIPELINE_ARGS
 pipeline = PIPELINE(model, "20B_tokenizer.json")

@@ -183,6 +194,6 @@ demo = gr.TabbedInterface(

 demo.queue(max_size=5)
 if args.share:
-    demo.launch(share=True)
+    demo.launch(share=True,server_name="0.0.0.0",server_port=58888)
 else:
-    demo.launch(share=False)
+    demo.launch(share=False,server_port=58888)
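
The --lora_layer_filter string is parsed by the get_filter_keys helper added in rwkv_lora.py below. A minimal sketch of the accepted formats, using illustrative state-dict key names:

from rwkv_lora import get_filter_keys

# Hypothetical keys in the style of an RWKV LoRA checkpoint.
keys = ["emb.weight", "blocks.0.att.key.lora_A", "blocks.25.att.key.lora_A", "blocks.31.ffn.value.lora_B"]

# Entries are space-separated; each is a single layer index or an inclusive "start-end" range (e.g. "0 5 25-31").
# Keys outside "blocks.*" always pass through the filter.
print(get_filter_keys("25-31")(keys))  # ['emb.weight', 'blocks.25.att.key.lora_A', 'blocks.31.ffn.value.lora_B']
print(get_filter_keys(None)(keys))     # no filter: all keys are kept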
rwkv_lora.py ADDED
@@ -0,0 +1,325 @@
+from collections import OrderedDict
+from typing import Dict
+import typing
+
+from rwkv.model import RWKV as RWKV_UPSTREAM
+import types, gc, os, time, re
+import torch
+from torch.nn import functional as F
+
+def get_filter_keys(layer_filter):
+    if layer_filter:
+        layers = []
+        for layer in layer_filter.split(' '):
+            if layer.isdecimal():
+                layers.append(int(layer))
+            elif '-' in layer:
+                start,_,end = layer.partition('-')
+                start,end = int(start),int(end)
+                layers.extend(range(start,end+1))
+            else:
+                raise NotImplementedError("layer_filter Not implemented:",layer_filter)
+        layers = sorted(set(layers))
+        layer_prefixes = tuple(f"blocks.{l}." for l in layers)
+        def filter_keys(keys):
+            new_keys = []
+            for key in keys:
+                if key.startswith("blocks."):
+                    if not key.startswith(layer_prefixes):
+                        continue
+                new_keys.append(key)
+            return new_keys
+
+    else:
+        def filter_keys(keys):
+            return keys
+    return filter_keys
+
+def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
+    print(f"Loading LoRA: {lora}")
+    print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}")
+    filter_keys = get_filter_keys(layer_filter)
+    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
+    # merge LoRA-only slim checkpoint into the main weights
+    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
+    # pdb.set_trace() #DEBUG
+    for k in filter_keys(w_lora.keys()): # handle merged tensors such as time_mixing
+        w[k] = w_lora[k]
+    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
+    # merge LoRA weights
+    keys = list(w.keys())
+    for k in keys:
+        if k.endswith('.weight'):
+            prefix = k[:-len('.weight')]
+            lora_A = prefix + '.lora_A'
+            lora_B = prefix + '.lora_B'
+            if lora_A in keys:
+                assert lora_B in keys
+                print(f'merging {lora_A} and {lora_B} into {k}')
+                assert w[lora_B].shape[1] == w[lora_A].shape[0]
+                lora_r = w[lora_B].shape[1]
+                w[k] = w[k].to(device=device)
+                w[lora_A] = w[lora_A].to(device=device)
+                w[lora_B] = w[lora_B].to(device=device)
+                w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
+                output_w[k] = w[k].to(device='cpu', copy=True)
+                del w[k]
+                del w[lora_A]
+                del w[lora_B]
+                continue
+
+        if 'lora' not in k:
+            print(f'retaining {k}')
+            output_w[k] = w[k].clone()
+            del w[k]
+    return output_w
+
+class RWKV(RWKV_UPSTREAM):
+    def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None,lora=None,lora_alpha=0,lora_layer_filter=None):
+        super(RWKV_UPSTREAM,self).__init__()
+        if verbose:
+            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
+        else:
+            prxxx = lambda *args, **kwargs: None
+
+        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+        if not re.match(STRATEGY_REGEX, strategy):
+            raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")
+
+        strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ')
+        self.args = types.SimpleNamespace()
+        args = self.args
+        args.MODEL_NAME = model
+        args.strategy_string = strategy
+
+        # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
+        self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
+        prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n')
+
+        args.MODEL_NAME = args.MODEL_NAME.strip()
+        if not args.MODEL_NAME.endswith('.pth'):
+            args.MODEL_NAME += '.pth'
+        prxxx(f'Loading {args.MODEL_NAME} ...')
+        with torch.no_grad():
+            if lora:
+                self.w = lora_merge(base_model=args.MODEL_NAME,lora=lora,
+                                    lora_alpha=lora_alpha,layer_filter=lora_layer_filter,
+                                    device=('cuda' if 'cuda' in strategy else 'cpu'))
+            else:
+                self.w = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first
+            gc.collect()
+            w = self.w
+            ALREADY_CONVERTED = False
+            if '_strategy' in w:
+                ALREADY_CONVERTED = True
+                assert convert_and_save_and_exit == None # you should only convert a raw model
+                prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n")
+                assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model
+                assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py
+                assert w['_rescale_layer'] == self.RESCALE_LAYER
+                del w['_strategy']
+                del w['_version']
+                del w['_rescale_layer']
+
+            args.n_embd = w['emb.weight'].shape[1]
+            args.n_layer = 0
+            keys = list(w.keys())
+            for x in keys:
+                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
+                args.n_layer = max(args.n_layer, layer_id+1)
+
+            ####################### Compute strategy
+
+            s = [x.strip().split(' ') for x in strategy.split('->')]
+            plan = [0] * len(s)
+            stream_i = -1
+            stream_count = 0
+            to_allocate = args.n_layer + 1
+            allocated = 0
+            free_slots = 0
+            for i in range(len(s)):
+                si = s[i]
+                si1 = si[1]
+                if si1.startswith('fp32'): si[1] = [torch.float]
+                elif si1.startswith('fp16'): si[1] = [torch.float16]
+                elif si1.startswith('bf16'): si[1] = [torch.bfloat16]
+                if si1.endswith('i8'): si[1] += [torch.uint8]
+                else: si[1] += [si[1][0]]
+                if len(si) > 2:
+                    ss = si[2]
+                    assert ss.startswith('*')
+                    if ss.endswith('+'):
+                        plan[i] = int(ss[1:-1])
+                        stream_i = i
+                    else:
+                        plan[i] = int(ss[1:])
+                    allocated += plan[i]
+                    if allocated >= to_allocate:
+                        plan[i] += to_allocate - allocated
+                        break
+                else:
+                    free_slots += 1
+            if stream_i < 0:
+                if free_slots > 0 and to_allocate > allocated:
+                    for i in range(len(s)):
+                        if plan[i] == 0:
+                            plan[i] = (to_allocate - allocated) // free_slots
+                            allocated += plan[i]
+                            free_slots -= 1
+                if to_allocate > allocated:
+                    plan[len(s)-1] += to_allocate - allocated
+            else:
+                if to_allocate > allocated:
+                    stream_count = to_allocate - allocated
+                    plan[stream_i] += stream_count
+            prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)')
+            for i in range(len(s)):
+                ss = s[i]
+                if i != stream_i:
+                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers')
+                else:
+                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers')
+                plan[i] += (0 if i == 0 else plan[i-1])
+            self.strategy = [None] * (args.n_layer + 1)
+            strategy = self.strategy
+            for n in range(args.n_layer + 1):
+                for i in range(len(s)):
+                    if n < plan[i]:
+                        strategy[n] = types.SimpleNamespace()
+                        strategy[n].device = s[i][0]
+                        strategy[n].atype = s[i][1][0]
+                        strategy[n].wtype = s[i][1][1]
+                        strategy[n].stream = False
+                        if i == stream_i and n >= (plan[i] - stream_count):
+                            strategy[n].stream = True
+                        break
+                prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ')
+            prxxx()
+
+            ####################### Load weights to self.w
+
+            if not ALREADY_CONVERTED:
+                try: # precompute embedding
+                    w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
+                except:
+                    w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
+                del w['blocks.0.ln0.weight']
+                del w['blocks.0.ln0.bias']
+
+            print_need_newline = False
+            keys = list(w.keys())
+            for x in keys:
+                w[x].requires_grad = False
+                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
+                if ('ln_out.' in x) or ('head.' in x):
+                    layer_id = args.n_layer
+                dd = strategy[layer_id]
+                DEVICE = dd.device
+                ATYPE = dd.atype
+                WTYPE = dd.wtype
+
+                if not ALREADY_CONVERTED:
+                    if self.RESCALE_LAYER > 0:
+                        if 'att.output.weight' in x:
+                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
+                        if 'ffn.value.weight' in x:
+                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
+
+                    if '.time_' in x:
+                        w[x] = w[x].squeeze()
+                    if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x:
+                        w[x] = w[x].t()
+
+                    if '.time_decay' in x: # need fp32 for this
+                        w[x] = -torch.exp(w[x].float())
+                    elif '.time_first' in x: # need fp32 for this
+                        w[x] = w[x].float()
+                    else:
+                        if (len(w[x].shape) == 2) and ('emb' not in x):
+                            if WTYPE != torch.uint8:
+                                w[x] = w[x].to(dtype=WTYPE)
+                            else:
+                                w[x] = w[x].float()
+
+                                if w[x].shape[0] > w[x].shape[1]:
+                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
+                                    w[x] = w[x] - w[x+'_my']
+                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
+                                    w[x] = w[x] - w[x+'_mx']
+                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
+                                    w[x] = w[x] / w[x+'_rx']
+                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
+                                    w[x] = w[x] / w[x+'_ry']
+                                else:
+                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
+                                    w[x] = w[x] - w[x+'_mx']
+                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
+                                    w[x] = w[x] - w[x+'_my']
+                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
+                                    w[x] = w[x] / w[x+'_rx']
+                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
+                                    w[x] = w[x] / w[x+'_ry']
+
+                                w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8)
+                                w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous()
+                                w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous()
+                                w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous()
+                                w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous()
+                        else:
+                            w[x] = w[x].to(dtype=ATYPE)
+
+                if convert_and_save_and_exit == None:
+                    if 'emb.' in x:
+                        w[x] = w[x].contiguous()
+                    elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')):
+                        try:
+                            w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :)
+                        except:
+                            print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.')
+                    elif DEVICE != 'cpu':
+                        w[x] = w[x].to(device=DEVICE).contiguous()
+
+                    if (dd.stream) or (DEVICE != 'cpu'):
+                        try:
+                            w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous()
+                            w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous()
+                            w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous()
+                            w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous()
+                        except:
+                            pass
+
+                if 'ffn.value.weight' in x:
+                    gc.collect()
+                    if 'cuda' in args.strategy_string:
+                        torch.cuda.empty_cache()
+
+                shape = [i for i in w[x].shape if i != 1]
+                if len(shape) > 1:
+                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
+                else:
+                    shape = f" {str(shape[0]).rjust(5)} "
+                if layer_id == 0 or layer_id >= args.n_layer-1:
+                    if print_need_newline:
+                        prxxx('\n', end = '')
+                        print_need_newline = False
+                    dt = str(w[x].dtype).replace('torch.', '')
+                    dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8')
+                    prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '')
+                else:
+                    print_need_newline = True
+                    prxxx('.', end = '', flush = True)
+
+            if convert_and_save_and_exit:
+                w['_strategy'] = args.strategy_string
+                w['_rescale_layer'] = self.RESCALE_LAYER
+                w['_version'] = '0.7'
+                if not convert_and_save_and_exit.endswith('.pth'):
+                    convert_and_save_and_exit += '.pth'
+                prxxx(f'Saving to {convert_and_save_and_exit}...')
+                torch.save(w, convert_and_save_and_exit)
+                prxxx(f'Converted and saved. Now this will exit.')
+                exit(0)
+
+            gc.collect()
+            if 'cuda' in args.strategy_string:
+                torch.cuda.empty_cache()
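
For readers skimming the diff, the heart of lora_merge above is the standard low-rank update W += (alpha / r) * B @ A, with the rank r read from the LoRA checkpoint itself. A minimal, self-contained sketch on toy tensors; the shapes and alpha value are illustrative only:

import torch

# Toy example: a 4x4 weight with a rank-2 LoRA pair and alpha = 16.
out_dim, in_dim, r, alpha = 4, 4, 2, 16.0
W = torch.zeros(out_dim, in_dim)   # plays the role of '<prefix>.weight'
B = torch.randn(out_dim, r)        # '<prefix>.lora_B'
A = torch.randn(r, in_dim)         # '<prefix>.lora_A'

# The same update lora_merge applies to every weight with a matching lora_A/lora_B pair;
# in lora_merge, r is taken from lora_B.shape[1].
W += B @ A * (alpha / r)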