Model card for RWKV-4 | 3B parameters trained on Pile dataset
RWKV is a project led by Bo Peng. Learn more about the model architecture in the blogposts from Johan Wind here and here. Learn more about the project by joining the RWKV discord server.
Table of contents
TL;DR
Below is the description from the original repository
RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). It's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
Model Details
The details of the architecture can be found on the blogpost mentioned above and the Hugging Face blogpost of the integration.
Usage
Convert the raw weights to the HF format
You can use the convert_rwkv_checkpoint_to_hf.py
script by specifying the repo_id of the original weights, the filename and the output directory. You can also optionally directly push the converted model on the Hub by passing --push_to_hub
flag and --model_name
argument to specify where to push the converted weights.
python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv
Generate text
You can use the AutoModelForCausalLM
and AutoTokenizer
classes to generate texts from the model. Expand the sections below to understand how to run the model in different scenarios:
Running the model on a CPU
Click to expand
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
Running the model on a single GPU
Click to expand
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile").to(0)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
Running the model in half-precision, on GPU
Click to expand
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile", torch_dtype=torch.float16).to(0)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
Running the model multiple GPUs
Click to expand
# pip install accelerate
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-3b-pile", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-3b-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
Citation
If you use this model, please consider citing the original work, from the original repo here
- Downloads last month
- 793