增加lora载入时去除部分模块的正则表达式写法
Browse files
- app.py +1 -1
- rwkv_lora.py +23 -2
app.py
CHANGED
@@ -17,7 +17,7 @@ parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch
|
|
17 |
parser.add_argument('--model_path',type=str,default=None,help="local model path")
|
18 |
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
|
19 |
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
|
20 |
-
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "25-31"')
|
21 |
args = parser.parse_args()
|
22 |
os.environ["RWKV_JIT_ON"] = '1'
|
23 |
|
|
|
17 |
parser.add_argument('--model_path',type=str,default=None,help="local model path")
|
18 |
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path')
|
19 |
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha')
|
20 |
+
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"')
|
21 |
args = parser.parse_args()
|
22 |
os.environ["RWKV_JIT_ON"] = '1'
|
23 |
|
rwkv_lora.py
CHANGED
@@ -7,11 +7,21 @@ import types, gc, os, time, re
|
|
7 |
import torch
|
8 |
from torch.nn import functional as F
|
9 |
|
|
|
10 |
def get_filter_keys_and_merge_coef(layer_filter):
|
11 |
if layer_filter:
|
12 |
layers = []
|
13 |
layer_coef = {}
|
|
|
14 |
for layer in layer_filter.split(' '):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
if '*' in layer:
|
16 |
coef,_,layer = layer.partition('*')
|
17 |
coef = float(coef)
|
@@ -20,22 +30,31 @@ def get_filter_keys_and_merge_coef(layer_filter):
|
|
20 |
if layer.isdecimal():
|
21 |
layers.append(int(layer))
|
22 |
layer_coef[int(layer)]=coef
|
|
|
23 |
elif '-' in layer:
|
24 |
start,_,end = layer.partition('-')
|
25 |
start,end = int(start),int(end)
|
26 |
layers.extend(range(start,end+1))
|
27 |
for l in range(start,end+1):
|
28 |
layer_coef[l] = coef
|
|
|
29 |
else:
|
30 |
raise NotImplementedError("layer_filter Not implemented:",layer_filter)
|
31 |
layers = sorted(set(layers))
|
32 |
-
layer_prefixes = tuple(f"blocks.{l}." for l in layers)
|
33 |
def filter_keys(keys):
|
34 |
new_keys = []
|
35 |
for key in keys:
|
|
|
|
|
36 |
if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
|
37 |
-
|
|
|
|
|
|
|
38 |
continue
|
|
|
|
|
39 |
new_keys.append(key)
|
40 |
return new_keys
|
41 |
def merge_coef(key):
|
@@ -59,6 +78,8 @@ def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
|
|
59 |
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
60 |
# pdb.set_trace() #DEBUG
|
61 |
for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
|
|
|
|
|
62 |
w[k] = w_lora[k]
|
63 |
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
64 |
# merge LoRA weights
|
|
|
7 |
import torch
|
8 |
from torch.nn import functional as F
|
9 |
|
10 |
+
# valid_filter_pattern = r"(((\d+(\.\d+)?\*)?(\d+)(-\d+)?(/\S+)?|(/\S+))(\s+|$))+"
|
11 |
def get_filter_keys_and_merge_coef(layer_filter):
|
12 |
if layer_filter:
|
13 |
layers = []
|
14 |
layer_coef = {}
|
15 |
+
layer_remove_patterns = {}
|
16 |
for layer in layer_filter.split(' '):
|
17 |
+
if '/' in layer: #过滤pattern,需要写成正则表达式
|
18 |
+
layer,_,remove_pattern = layer.partition('/')
|
19 |
+
remove_pattern = re.compile(remove_pattern)
|
20 |
+
else:
|
21 |
+
remove_pattern = None
|
22 |
+
if layer=='':
|
23 |
+
layer_remove_patterns['global']=remove_pattern
|
24 |
+
continue
|
25 |
if '*' in layer:
|
26 |
coef,_,layer = layer.partition('*')
|
27 |
coef = float(coef)
|
|
|
30 |
if layer.isdecimal():
|
31 |
layers.append(int(layer))
|
32 |
layer_coef[int(layer)]=coef
|
33 |
+
layer_remove_patterns[int(layer)]=remove_pattern
|
34 |
elif '-' in layer:
|
35 |
start,_,end = layer.partition('-')
|
36 |
start,end = int(start),int(end)
|
37 |
layers.extend(range(start,end+1))
|
38 |
for l in range(start,end+1):
|
39 |
layer_coef[l] = coef
|
40 |
+
layer_remove_patterns[l]=remove_pattern
|
41 |
else:
|
42 |
raise NotImplementedError("layer_filter Not implemented:",layer_filter)
|
43 |
layers = sorted(set(layers))
|
44 |
+
# layer_prefixes = tuple(f"blocks.{l}." for l in layers)
|
45 |
def filter_keys(keys):
|
46 |
new_keys = []
|
47 |
for key in keys:
|
48 |
+
if layer_remove_patterns.get("global") and layer_remove_patterns['global'].search(key):
|
49 |
+
continue #符合全局去除规则
|
50 |
if key.startswith("blocks."): #过滤掉blocks开头,且不在允许范围内的权重
|
51 |
+
l = int(key.split('.')[1])
|
52 |
+
if l not in layers: #不在允许层,过滤掉
|
53 |
+
continue
|
54 |
+
if layer_remove_patterns[l] and layer_remove_patterns[l].search(key): #符合对应层的去除规则,过滤掉
|
55 |
continue
|
56 |
+
# if not key.startswith(layer_prefixes):
|
57 |
+
# continue
|
58 |
new_keys.append(key)
|
59 |
return new_keys
|
60 |
def merge_coef(key):
|
|
|
78 |
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
79 |
# pdb.set_trace() #DEBUG
|
80 |
for k in filter_keys(w_lora.keys()): #处理time_mixing之类的融合
|
81 |
+
if k in w:
|
82 |
+
print(f"replacing {k}")
|
83 |
w[k] = w_lora[k]
|
84 |
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
85 |
# merge LoRA weights
|