import math, os, sys, types, time, gc

import numpy as np
import torch

from src.utils import TOKENIZER

try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except IndexError:
    pass
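# Assumed invocation: python run.py [GPU_ID]; the optional argument above
# pins the process to a single GPU via CUDA_VISIBLE_DEVICES.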
# cuDNN autotuning and TF32 matmuls speed up GPU inference
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

np.set_printoptions(precision=4, suppress=True, linewidth=200)

args = types.SimpleNamespace()
args.RUN_DEVICE = "cpu"   # "cpu" or "cuda"
args.FLOAT_MODE = "fp32"  # "fp32" (CPU-friendly), "fp16" or "bf16"

os.environ["RWKV_JIT_ON"] = '1'  # enable the TorchScript JIT model code
TOKEN_MODE = "pile"
WORD_NAME = [
    "20B_tokenizer_openchatgpt.json",
    "20B_tokenizer_openchatgpt.json",
]
UNKNOWN_CHAR = None
vocab_size = 50277 + 1  # base GPT-NeoX vocab plus one extra special token
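# Note: vocab_size - 1 (= 50277) is assumed to be the id of the <|STK_SP|>
# separator used in the prompt below; generation stops once it is sampled.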
MODEL_NAME = './out2/rwkv-5'
n_layer = 24
n_embd = 1024
ctx_len = 1024
args.MODEL_NAME = MODEL_NAME
args.n_layer = n_layer
args.n_embd = n_embd
args.ctx_len = ctx_len
args.vocab_size = vocab_size
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0

os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
context = """quality: high

[System]
Assistant is a distilled language model trained by the community.<|STK_SP|>

[System]
<|STK_SP|>

[User]
Hi!<|STK_SP|>

[Assistant]
"""
NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333

TEMPERATURE = 1.0
top_p = 0.8
top_p_newline = 0.9  # alternative top-p applied right after a newline (char mode)

DEBUG_DEBUG = False  # True: run one trial of one token and dump raw logits
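# For reference, a minimal sketch of the temperature + top-p (nucleus)
# sampling that tokenizer.sample_logits() is assumed to perform (the real
# implementation lives in src.utils; this helper is illustrative and unused):
def _sample_top_p_sketch(logits, temperature=1.0, top_p=0.8):
    probs = torch.softmax(torch.as_tensor(logits).float(), dim=-1).cpu().numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    # Smallest probability still inside the nucleus of mass top_p
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0  # drop everything outside the nucleus
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    return int(np.random.choice(len(probs), p=probs))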
print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...')
from src.model_run import RWKV_RNN

model = RWKV_RNN(args)

print('\nOptimizing speed...')
out, _ = model.forward([187], None)  # warm-up with a single '\n' token (id 187)

gc.collect()
torch.cuda.empty_cache()
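# model.forward(tokens, state) returns (logits, new_state); with
# preprocess_only=True it returns only the updated state. Both call
# patterns are used below to stream the prompt through the RNN.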
print(f'\nLoading tokenizer {WORD_NAME}...')
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == "pile":
    assert tokenizer.tokenizer.decode([187]) == '\n'
if tokenizer.charMode:
    context = tokenizer.refine_context(context)
    ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
else:
    ctx = tokenizer.tokenizer.encode(context)
src_len = len(ctx)
src_ctx = ctx.copy()
print("\nYour prompt has " + str(src_len) + " tokens.")
print(
    "Note: the first run takes a while if your prompt is long, because the RNN preprocesses the prompt token by token. Use GPT mode to build the hidden state for better speed.\n"
)
time_slot = {}
time_ref = time.time_ns()

def record_time(name):
    # Track the fastest observed wall-clock time (in seconds) for each stage
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt
init_state = None
init_out = None
state = None
out = None
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    print(("-" * 50) + '\n' + context, end="")

    time_ref = time.time_ns()
    ctx = src_ctx.copy()

    if TRIAL == 0:
        # First trial: stream the prompt through the RNN once and cache
        # the resulting hidden state for reuse by every later trial.
        for i in range(src_len):
            x = ctx[: i + 1]
            if i == src_len - 1:
                init_out, init_state = model.forward(x, init_state)
            else:
                init_state = model.forward(x, init_state, preprocess_only=True)
        gc.collect()
        torch.cuda.empty_cache()

    record_time('preprocess')
    out_last = src_len
    for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)):
        x = ctx[: i + 1]
        x = x[-ctx_len:]

        if i == src_len:
            # Reuse the cached prompt state instead of re-reading the prompt
            out = init_out.clone()
            state = init_state.clone()
        else:
            out, state = model.forward(x, state)
        if DEBUG_DEBUG:
            print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy()))
        if TOKEN_MODE == "pile":
            out[0] = -999999999  # disable <|endoftext|> (token 0)

        ttt = tokenizer.sample_logits(
            out,
            x,
            ctx_len,
            temperature=TEMPERATURE,
            top_p_usual=top_p,
            top_p_newline=top_p_newline,
        )
        ctx += [ttt]

        if ttt == vocab_size - 1:
            break  # sampled the extra stop token, end this trial

        if tokenizer.charMode:
            char = tokenizer.itos[ttt]
            print(char, end="", flush=True)
        else:
            char = tokenizer.tokenizer.decode(ctx[out_last:])
            if '\ufffd' not in char:  # only print complete UTF-8 sequences
                print(char, end="", flush=True)
                out_last = i + 1

    record_time('total')

    print(
        f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total'] - time_slot['preprocess'], 2)}s ", end=''
    )
    print(("-" * 50) + '\n')