import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F
import os

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
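
# Greedy byte-level tokenizer for the RWKV world vocab: at each position it
# emits the longest matching vocab entry, falling back to a single byte.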
class RWKV_TOKENIZER():
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        self.idx2token = {}
        sorted_tokens = []  # the vocab file is already sorted; renamed from `sorted` to avoid shadowing the builtin
        lines = open(file_name, "r", encoding="utf-8").readlines()
        for l in lines:
            # each vocab line is: <idx> <token repr> <byte length>
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted_tokens += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # precompute lookup tables keyed by the first two bytes:
        # table[s0][s1] = candidate tokens, good[s0] = valid second bytes,
        # wlen[s0] = length of the longest token starting with byte s0
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for i in reversed(range(len(sorted_tokens))):  # reversed, so longer tokens are tried first
            s = sorted_tokens[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # default: a single byte

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    try:
                        # first candidate that is a prefix of sss wins (longest first)
                        s = next(filter(sss.startswith, self.table[s0][s1]))
                    except StopIteration:
                        pass
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass  # keep raw bytes if the token is not valid utf-8 on its own
            print(f'{repr(s)}{i}', end=' ')
        print()

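# Top-p (nucleus) sampling: drop tokens outside the smallest set whose
# cumulative probability exceeds top_p, apply the temperature, then sample.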
def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0  # drop the low-probability tail
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)  # renormalize so probs sum to 1 before sampling
    out = np.random.choice(a=len(probs), p=probs)
    return out

tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
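
# Quick sanity check (a sketch; assumes the vocab file above is on disk):
#   ids = tokenizer.encode("Kobe Bryant")
#   assert tokenizer.decode(ids) == "Kobe Bryant"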

args = types.SimpleNamespace()
args.MODEL_NAME = 'OnlySportsLM'
args.n_layer = 20
args.n_embd = 640
args.vocab_size = 65536

context = """Kobe Bryant"""

NUM_TRIALS = 1
LENGTH_PER_TRIAL = 120
TEMPERATURE = 0.5
TOP_P = 0.7
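
# RWKV-6 in RNN mode: one token in, one set of logits out. Per layer the state
# holds (2 + head_size) rows of width n_embd: row 0 is the channel-mixing token
# shift, row 1 the time-mixing token shift, and the remaining head_size rows
# the per-head wkv state.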
class RWKV_RNN(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval()  # inference only

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')

        for k in w.keys():
            w[k] = w[k].float()  # compute everything in fp32
            if '.time_' in k: w[k] = w[k].squeeze()
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        # mirror the flat checkpoint keys as nested attributes, e.g.
        # w['blocks.0.att.key.weight'] -> self.w.blocks[0].att.key.weight
        self.w = types.SimpleNamespace()
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
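
    # Channel mixing (RWKV's FFN): blend the current input with the previous
    # token's (token shift), then gate a squared-ReLU two-layer projection.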
    @MyFunction
    def channel_mixing(self, x, state, i: int, time_maa_k, time_maa_r, kw, vw, rw):
        i0 = (2 + self.head_size) * i + 0  # row holding this layer's channel-mix shift
        sx = state[i0] - x
        xk = x + sx * time_maa_k
        xr = x + sx * time_maa_r
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk))  # squared ReLU
        return r * (vw @ k)
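
    # Time mixing (RWKV's attention substitute). The low-rank projections
    # tm_w1/tm_w2 and td_w1/td_w2 make the token-shift blend and the decay w
    # data-dependent, which is the characteristic RWKV-6 change.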
    @MyFunction
    def time_mixing(self, x, state, i: int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size

        i1 = (2 + S) * i + 1  # row holding this layer's time-mix shift
        sx = state[i1] - x
        state[i1] = x
        xxx = x + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)  # per-channel offsets for w, k, v, r, g

        xw = x + sx * (w_maa + mw)
        xk = x + sx * (k_maa + mk)
        xv = x + sx * (v_maa + mv)
        xr = x + sx * (r_maa + mr)
        xg = x + sx * (g_maa + mg)

        # data-dependent decay, mapped into (0, 1) via exp(-exp(.))
        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
        w = torch.exp(-torch.exp(w.float()))

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2 + S) * i + 2:(2 + S) * (i + 1), :].reshape(H, S, S)

        # one step of the wkv recurrence: out = r @ (time_first * k v^T + s), s <- k v^T + w * s
        a = k @ v
        x = r @ (time_first * a + s)
        s = a + w * s

        state[(2 + S) * i + 2:(2 + S) * (i + 1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).squeeze(0) * g
        return ow @ x
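
    # One token step: embed, ln0, then per block time mixing + channel mixing
    # with residual connections, and finally project to vocab logits.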
    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2 + self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_maa_k, ffn.time_maa_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)
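
# Prefill: feed the prompt token by token to build the initial recurrent state.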
init_state = None
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out_str = ''

    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # only emit once the pending bytes form valid utf-8
                print(tmp, end="", flush=True)
                out_str += tmp
                out_last = i + 1
        except UnicodeDecodeError:
            pass  # a token boundary fell inside a multi-byte character; wait for more tokens
        out, state = model.forward(token, state)

    print('\n')