---
license: apache-2.0
library_name: transformers
pipeline_tag: text-generation
tags:
- 8bit
- sharded
- open_llama
inference: false
---
# open_llama_13b-sharded-8bit
This is `open_llama_13b` sharded into 2 GB shards and quantized to 8-bit precision using `bitsandbytes==0.38.0`. Please refer to the original model card for details.
## Loading
Install the required packages:

```sh
pip install -U -q sentencepiece transformers accelerate bitsandbytes
```
Then load the model and tokenizer:
```python
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

model_name = "ethzanalytics/open_llama_13b-sharded-8bit"

# the slow (SentencePiece) tokenizer is used for this model
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_fast=False)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,   # requires bitsandbytes
    device_map="auto",   # dispatch the shards across available devices
)
```