wanicca committed on
Commit
0f9dd39
1 Parent(s): 4a4232a

增加lora载入时去除部分模块的正则表达式写法

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. rwkv_lora.py +23 -2
app.py CHANGED
@@ -17,7 +17,7 @@ parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch
17
  parser.add_argument('--model_path',type=str,default=None,help="local model path")
18
  parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
19
  parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
20
- parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "25-31"')
21
  args = parser.parse_args()
22
  os.environ["RWKV_JIT_ON"] = '1'
23
 
 
17
  parser.add_argument('--model_path',type=str,default=None,help="local model path")
18
  parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
19
  parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
20
+ parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"')
21
  args = parser.parse_args()
22
  os.environ["RWKV_JIT_ON"] = '1'
23
 
rwkv_lora.py CHANGED
@@ -7,11 +7,21 @@ import types, gc, os, time, re
7
  import torch
8
  from torch.nn import functional as F
9
 
 
10
  def get_filter_keys_and_merge_coef(layer_filter):
11
  if layer_filter:
12
  layers = []
13
  layer_coef = {}
 
14
  for layer in layer_filter.split(' '):
 
 
 
 
 
 
 
 
15
  if '*' in layer:
16
  coef,_,layer = layer.partition('*')
17
  coef = float(coef)
@@ -20,22 +30,31 @@ def get_filter_keys_and_merge_coef(layer_filter):
20
  if layer.isdecimal():
21
  layers.append(int(layer))
22
  layer_coef[int(layer)]=coef
 
23
  elif '-' in layer:
24
  start,_,end = layer.partition('-')
25
  start,end = int(start),int(end)
26
  layers.extend(range(start,end+1))
27
  for l in range(start,end+1):
28
  layer_coef[l] = coef
 
29
  else:
30
  raise NotImplementedError("layer_filter Not implemented:",layer_filter)
31
  layers = sorted(set(layers))
32
- layer_prefixes = tuple(f"blocks.{l}." for l in layers)
33
  def filter_keys(keys):
34
  new_keys = []
35
  for key in keys:
 
 
36
  if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
37
- if not key.startswith(layer_prefixes):
 
 
 
38
  continue
 
 
39
  new_keys.append(key)
40
  return new_keys
41
  def merge_coef(key):
@@ -59,6 +78,8 @@ def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
59
  w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
60
  # pdb.set_trace() #DEBUG
61
  for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
 
 
62
  w[k] = w_lora[k]
63
  output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
64
  # merge LoRA weights
 
7
  import torch
8
  from torch.nn import functional as F
9
 
10
+ # valid_filter_pattern = r"(((\d+\.\d+\*)?(\d+)(-\d+)?(/\S+)?|(/\S+))(\s+|$))+"
11
  def get_filter_keys_and_merge_coef(layer_filter):
12
  if layer_filter:
13
  layers = []
14
  layer_coef = {}
15
+ layer_remove_patterns = {}
16
  for layer in layer_filter.split(' '):
17
+ if '/' in layer: #过滤pattern,需要写成正则表达式
18
+ layer,_,remove_pattern = layer.partition('/')
19
+ remove_pattern = re.compile(remove_pattern)
20
+ else:
21
+ remove_pattern = None
22
+ if layer=='':
23
+ layer_remove_patterns['global']=remove_pattern
24
+ continue
25
  if '*' in layer:
26
  coef,_,layer = layer.partition('*')
27
  coef = float(coef)
 
30
  if layer.isdecimal():
31
  layers.append(int(layer))
32
  layer_coef[int(layer)]=coef
33
+ layer_remove_patterns[int(layer)]=remove_pattern
34
  elif '-' in layer:
35
  start,_,end = layer.partition('-')
36
  start,end = int(start),int(end)
37
  layers.extend(range(start,end+1))
38
  for l in range(start,end+1):
39
  layer_coef[l] = coef
40
+ layer_remove_patterns[l]=remove_pattern
41
  else:
42
  raise NotImplementedError("layer_filter Not implemented:",layer_filter)
43
  layers = sorted(set(layers))
44
+ # layer_prefixes = tuple(f"blocks.{l}." for l in layers)
45
  def filter_keys(keys):
46
  new_keys = []
47
  for key in keys:
48
+ if layer_remove_patterns.get("global") and layer_remove_patterns['global'].search(key):
49
+ continue #符合全局去除规则
50
  if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
51
+ l = int(key.split('.')[1])
52
+ if l not in layers: #不在允许层,过滤掉
53
+ continue
54
+ if layer_remove_patterns[l] and layer_remove_patterns[l].search(key): #符合对应层的去除规则,过滤掉
55
  continue
56
+ # if not key.startswith(layer_prefixes):
57
+ # continue
58
  new_keys.append(key)
59
  return new_keys
60
  def merge_coef(key):
 
78
  w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
79
  # pdb.set_trace() #DEBUG
80
  for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
81
+ if k in w:
82
+ print(f"replacing {k}")
83
  w[k] = w_lora[k]
84
  output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
85
  # merge LoRA weights