from collections import OrderedDict
from typing import Dict
import typing

from rwkv.model import RWKV as RWKV_UPSTREAM

import types, gc, os, time, re
import torch
from torch.nn import functional as F

def get_filter_keys_and_merge_coef(layer_filter):
    """Parse a layer_filter string into (filter_keys, merge_coef) helpers.

    The filter is a space-separated list of entries, each either a single
    layer index ("7"), an inclusive range ("0-5"), or either form scaled by
    a coefficient ("0.5*7", "0.5*0-5"). filter_keys() keeps only the
    LoRA-checkpoint keys belonging to the selected blocks (non-block keys
    always pass); merge_coef() returns the per-layer scaling applied when
    merging.
    """
    if layer_filter:
        layers = []
        layer_coef = {}
        for layer in layer_filter.split(' '):
            if '*' in layer:
                coef, _, layer = layer.partition('*')
                coef = float(coef)
            else:
                coef = 1
            if layer.isdecimal():
                layers.append(int(layer))
                layer_coef[int(layer)] = coef
            elif '-' in layer:
                start, _, end = layer.partition('-')
                start, end = int(start), int(end)
                layers.extend(range(start, end + 1))
                for l in range(start, end + 1):
                    layer_coef[l] = coef
            else:
                raise NotImplementedError(f"layer_filter not implemented: {layer_filter}")
        layers = sorted(set(layers))
        layer_prefixes = tuple(f"blocks.{l}." for l in layers)

        def filter_keys(keys):
            # Keep non-block keys unconditionally; keep block keys only for
            # the selected layers.
            new_keys = []
            for key in keys:
                if key.startswith("blocks."):
                    if not key.startswith(layer_prefixes):
                        continue
                new_keys.append(key)
            return new_keys

        def merge_coef(key):
            if key.startswith('blocks.') and int(key.split('.')[1]) in layer_coef:
                return layer_coef[int(key.split('.')[1])]
            else:
                return 1
    else:
        # No filter: accept every key and merge with coefficient 1.
        def filter_keys(keys):
            return keys

        def merge_coef(key):
            return 1
    return filter_keys, merge_coef

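# Illustrative example of the helpers above (values are hypothetical, not
# taken from any real checkpoint): with layer_filter="0-3 0.5*8",
#
#   fk, mc = get_filter_keys_and_merge_coef("0-3 0.5*8")
#   fk(["blocks.2.att.key.weight", "blocks.9.att.key.weight", "head.weight"])
#       # -> ["blocks.2.att.key.weight", "head.weight"]
#   mc("blocks.8.att.key.weight")   # -> 0.5
#   mc("blocks.2.att.key.weight")   # -> 1
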
def lora_merge(base_model, lora, lora_alpha, device="cuda", layer_filter=None):
    """Merge a LoRA checkpoint into a base model state dict and return it."""
    print(f"Loading LoRA: {lora}")
    print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}")
    filter_keys, merge_coef = get_filter_keys_and_merge_coef(layer_filter)
    w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')

    w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')

    # Copy every (filtered) tensor from the LoRA checkpoint into the working
    # dict: full replacement tensors overwrite the base weights, and the
    # lora_A / lora_B factors are merged below.
    for k in filter_keys(w_lora.keys()):
        w[k] = w_lora[k]
    output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()

    keys = list(w.keys())
    for k in keys:
        if k.endswith('.weight'):
            prefix = k[:-len('.weight')]
            lora_A = prefix + '.lora_A'
            lora_B = prefix + '.lora_B'
            if lora_A in keys:
                assert lora_B in keys
                print(f'merging {lora_A} and {lora_B} into {k}')
                assert w[lora_B].shape[1] == w[lora_A].shape[0]
                lora_r = w[lora_B].shape[1]
                w[k] = w[k].to(device=device)
                w[lora_A] = w[lora_A].to(device=device)
                w[lora_B] = w[lora_B].to(device=device)
                # LoRA merge: W <- W + (alpha / r) * B @ A, optionally scaled
                # per layer by the filter coefficient.
                w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) * merge_coef(k)
                output_w[k] = w[k].to(device='cpu', copy=True)
                del w[k]
                del w[lora_A]
                del w[lora_B]
                continue

        if 'lora' not in k:
            print(f'retaining {k}')
            output_w[k] = w[k].clone()
            del w[k]
    return output_w

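# Usage sketch for an offline merge (the file names below are hypothetical,
# not part of this repository): fold a LoRA into a base checkpoint with the
# chosen alpha and per-layer scaling, then save the result as a plain RWKV
# checkpoint.
#
#   merged = lora_merge(
#       base_model="RWKV-4-Pile-7B.pth",       # hypothetical base checkpoint
#       lora="my-lora.pth",                    # hypothetical LoRA checkpoint
#       lora_alpha=32,
#       layer_filter="0-15 0.5*16-31",         # optional per-layer scaling
#       device="cpu",
#   )
#   torch.save(merged, "merged.pth")
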

class RWKV(RWKV_UPSTREAM):
    def __init__(self, model, strategy, verbose=True, convert_and_save_and_exit=None, lora=None, lora_alpha=0, lora_layer_filter=None):
        super(RWKV_UPSTREAM, self).__init__()
        if verbose:
            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
        else:
            prxxx = lambda *args, **kwargs: None

        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
        if not re.match(STRATEGY_REGEX, strategy):
            raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")

        strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ')
        self.args = types.SimpleNamespace()
        args = self.args
        args.MODEL_NAME = model
        args.strategy_string = strategy

        # Rescale att.output / ffn.value every 6 layers when running fp16,
        # to keep activations from overflowing.
        self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0
        prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n')

        args.MODEL_NAME = args.MODEL_NAME.strip()
        if not args.MODEL_NAME.endswith('.pth'):
            args.MODEL_NAME += '.pth'
        prxxx(f'Loading {args.MODEL_NAME} ...')
        with torch.no_grad():
            if lora:
                # Merge the LoRA into the base checkpoint while loading.
                self.w = lora_merge(base_model=args.MODEL_NAME, lora=lora,
                                    lora_alpha=lora_alpha, layer_filter=lora_layer_filter,
                                    device=('cuda' if 'cuda' in strategy else 'cpu'))
            else:
                self.w = torch.load(args.MODEL_NAME, map_location='cpu')
            gc.collect()
            w = self.w
            ALREADY_CONVERTED = False
            if '_strategy' in w:
                ALREADY_CONVERTED = True
                assert convert_and_save_and_exit == None  # already converted; do not convert again
                prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n")
                assert w['_strategy'] == args.strategy_string
                assert float(w['_version']) >= 0.7
                assert w['_rescale_layer'] == self.RESCALE_LAYER
                del w['_strategy']
                del w['_version']
                del w['_rescale_layer']

            args.n_embd = w['emb.weight'].shape[1]
            args.n_layer = 0
            keys = list(w.keys())
            for x in keys:
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                args.n_layer = max(args.n_layer, layer_id + 1)

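            # Illustrative example of the strategy plan computed below
            # (hypothetical numbers): for a 24-layer model there are 24+1
            # slots (the extra one holds ln_out/head). With strategy
            # "cuda fp16 *10 -> cpu fp32", the plan stores layers 0-9 on cuda
            # in fp16 and layers 10-23 plus the head slot on cpu in fp32.
            # Writing "cuda fp16 *10+" instead marks cuda as the streaming
            # slot, so the unassigned layers stay in pinned CPU RAM and are
            # streamed to cuda on demand.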
            s = [x.strip().split(' ') for x in strategy.split('->')]
            plan = [0] * len(s)
            stream_i = -1
            stream_count = 0
            to_allocate = args.n_layer + 1
            allocated = 0
            free_slots = 0
            for i in range(len(s)):
                si = s[i]
                si1 = si[1]
                if si1.startswith('fp32'): si[1] = [torch.float]
                elif si1.startswith('fp16'): si[1] = [torch.float16]
                elif si1.startswith('bf16'): si[1] = [torch.bfloat16]
                if si1.endswith('i8'): si[1] += [torch.uint8]
                else: si[1] += [si[1][0]]
                if len(si) > 2:
                    ss = si[2]
                    assert ss.startswith('*')
                    if ss.endswith('+'):
                        plan[i] = int(ss[1:-1])
                        stream_i = i
                    else:
                        plan[i] = int(ss[1:])
                    allocated += plan[i]
                    if allocated >= to_allocate:
                        plan[i] += to_allocate - allocated
                        break
                else:
                    free_slots += 1
            if stream_i < 0:
                if free_slots > 0 and to_allocate > allocated:
                    for i in range(len(s)):
                        if plan[i] == 0:
                            plan[i] = (to_allocate - allocated) // free_slots
                            allocated += plan[i]
                            free_slots -= 1
                if to_allocate > allocated:
                    plan[len(s)-1] += to_allocate - allocated
            else:
                if to_allocate > allocated:
                    stream_count = to_allocate - allocated
                    plan[stream_i] += stream_count
            prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)')
            for i in range(len(s)):
                ss = s[i]
                if i != stream_i:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers')
                else:
                    prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers')
                plan[i] += (0 if i == 0 else plan[i-1])
            self.strategy = [None] * (args.n_layer + 1)
            strategy = self.strategy
            for n in range(args.n_layer + 1):
                for i in range(len(s)):
                    if n < plan[i]:
                        strategy[n] = types.SimpleNamespace()
                        strategy[n].device = s[i][0]
                        strategy[n].atype = s[i][1][0]
                        strategy[n].wtype = s[i][1][1]
                        strategy[n].stream = False
                        if i == stream_i and n >= (plan[i] - stream_count):
                            strategy[n].stream = True
                        break
                prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}", end=' ')
            prxxx()

            if not ALREADY_CONVERTED:
                # Pre-apply ln0 to the embedding table so the first LayerNorm
                # never has to run at inference time.
                try:
                    w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias'])
                except:
                    w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float())
                del w['blocks.0.ln0.weight']
                del w['blocks.0.ln0.bias']

            print_need_newline = False
            keys = list(w.keys())
            for x in keys:
                w[x].requires_grad = False
                layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0
                if ('ln_out.' in x) or ('head.' in x):
                    layer_id = args.n_layer
                dd = strategy[layer_id]
                DEVICE = dd.device
                ATYPE = dd.atype
                WTYPE = dd.wtype

                if not ALREADY_CONVERTED:
                    if self.RESCALE_LAYER > 0:
                        # Halve att.output / ffn.value every RESCALE_LAYER
                        # layers to avoid fp16 overflow.
                        if 'att.output.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
                        if 'ffn.value.weight' in x:
                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))

                    if '.time_' in x:
                        w[x] = w[x].squeeze()
                    if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x:
                        w[x] = w[x].t()

                    if '.time_decay' in x:  # kept in fp32
                        w[x] = -torch.exp(w[x].float())
                    elif '.time_first' in x:  # kept in fp32
                        w[x] = w[x].float()
                    else:
                        if (len(w[x].shape) == 2) and ('emb' not in x):
                            if WTYPE != torch.uint8:
                                w[x] = w[x].to(dtype=WTYPE)
                            else:
                                w[x] = w[x].float()

                                # uint8 path: shift by per-row / per-column
                                # minima and scale by maxima so the values
                                # land in [0, 1] before quantizing.
                                if w[x].shape[0] > w[x].shape[1]:
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']
                                else:
                                    w[x+'_mx'] = torch.amin(w[x], dim=0)
                                    w[x] = w[x] - w[x+'_mx']
                                    w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] - w[x+'_my']
                                    w[x+'_rx'] = torch.amax(w[x], dim=0)
                                    w[x] = w[x] / w[x+'_rx']
                                    w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1)
                                    w[x] = w[x] / w[x+'_ry']

                                w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8)
                                w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous()
                                w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous()
                                w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous()
                                w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous()
                        else:
                            w[x] = w[x].to(dtype=ATYPE)

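                # Note on the uint8 conversion above (a sketch of the scheme
                # as written in this file, not a spec of the upstream rwkv
                # kernels): each 2-D weight is shifted by its per-row and
                # per-column minima, scaled by its per-column and per-row
                # maxima into [0, 1], then multiplied by 256, floored and
                # clipped to uint8. The saved *_mx/_my offsets and *_rx/_ry
                # scales (each divided by 16, i.e. 256 combined) let the
                # matmul code approximately undo the transform later.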
                if convert_and_save_and_exit == None:
                    if 'emb.' in x:
                        w[x] = w[x].contiguous()
                    elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')):
                        try:
                            # Pin streamed weights so host-to-device copies can overlap compute.
                            w[x] = w[x].contiguous().pin_memory()
                        except:
                            print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.')
                    elif DEVICE != 'cpu':
                        w[x] = w[x].to(device=DEVICE).contiguous()

                    if (dd.stream) or (DEVICE != 'cpu'):
                        try:
                            w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous()
                            w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous()
                            w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous()
                            w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous()
                        except:
                            # Non-quantized tensors have no _mx/_rx/_my/_ry companions.
                            pass

                if 'ffn.value.weight' in x:
                    gc.collect()
                    if 'cuda' in args.strategy_string:
                        torch.cuda.empty_cache()

                shape = [i for i in w[x].shape if i != 1]
                if len(shape) > 1:
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
                else:
                    shape = f" {str(shape[0]).rjust(5)} "
                if layer_id == 0 or layer_id >= args.n_layer-1:
                    if print_need_newline:
                        prxxx('\n', end='')
                        print_need_newline = False
                    dt = str(w[x].dtype).replace('torch.', '')
                    dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8')
                    prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '')
                else:
                    print_need_newline = True
                    prxxx('.', end='', flush=True)

            if convert_and_save_and_exit:
                w['_strategy'] = args.strategy_string
                w['_rescale_layer'] = self.RESCALE_LAYER
                w['_version'] = '0.7'
                if not convert_and_save_and_exit.endswith('.pth'):
                    convert_and_save_and_exit += '.pth'
                prxxx(f'Saving to {convert_and_save_and_exit}...')
                torch.save(w, convert_and_save_and_exit)
                prxxx('Converted and saved. Now this will exit.')
                exit(0)

            gc.collect()
            if 'cuda' in args.strategy_string:
                torch.cuda.empty_cache()
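
# Usage sketch (the model path, strategy, and LoRA file below are
# hypothetical): the RWKV_JIT_ON / RWKV_CUDA_ON environment variables must be
# set beforehand, since this __init__ reads them from os.environ.
#
#   os.environ["RWKV_JIT_ON"] = "1"
#   os.environ["RWKV_CUDA_ON"] = "0"
#   model = RWKV(
#       model="RWKV-4-Pile-7B",                # '.pth' is appended automatically
#       strategy="cuda fp16 *20 -> cpu fp32",
#       lora="my-lora.pth",                    # hypothetical LoRA checkpoint
#       lora_alpha=32,
#       lora_layer_filter="0-15",              # optional: merge only blocks 0-15
#   )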