# NOTE: residue from the file-viewer page this source was copied from
# (Spaces status: Sleeping; file size: 9,492 bytes; revision 54200b7).
# The viewer's line-number gutter (1-222) has been dropped.
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
# Support running without installing as a package: append the parent of this file's
# directory to sys.path (presumably so the project-local `model`, `utils` and
# `tokenizer` modules imported right after this can be resolved — confirm layout).
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from model import *
from utils import *
from tokenizer import *
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single index from the distribution described by ``probs``.

    Outside of compilation this defers to ``torch.multinomial``. While the
    function is being traced by dynamo, an exponential-noise argmax is used
    instead, as a faster alternative to ``torch.multinomial(probs, num_samples=1)``
    that is also CUDAGraph friendly.

    Args:
        probs: Probabilities to sample from; sampling happens along the last dimension.

    Returns:
        The sampled index, with the sampled dimension kept at size 1.
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)

    # Dividing by Exp(1) noise and taking the argmax selects index i with
    # probability proportional to probs[i].
    noise = torch.empty_like(probs).exponential_(1)
    perturbed = probs / noise
    return torch.argmax(perturbed, dim=-1, keepdim=True)
def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Pick the next token id from the logits of the last position.

    Args:
        logits: Model output; only ``logits[0, -1]`` is used (presumably the
            last position of the first batch element — confirm against the model).
        temperature: Divides the logits before the softmax. Any value that is
            not strictly positive switches to greedy (argmax) selection.
        top_k: If given, only the ``top_k`` largest logits remain eligible.

    Returns:
        A tensor holding the chosen index, with the reduced dimension kept at size 1.
    """
    last_logits = logits[0, -1]

    # optionally restrict the candidates to the k largest logits
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        masked = torch.full_like(last_logits, float("-inf"))
        last_logits = masked.scatter_(-1, top_indices, top_values)

    # greedy decoding unless the temperature is strictly positive
    if not (temperature > 0.0):
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    # scale the logits and sample from the resulting probability distribution
    scaled = last_logits / temperature
    probs = torch.nn.functional.softmax(scaled, dim=-1)
    return multinomial_num_samples_1(probs)
def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run the model once and sample the token that follows ``x``.

    Args:
        model: The model to call as ``model(x, input_pos)``.
        input_pos: Positions passed to the model alongside ``x``.
        x: Input token ids fed to the model.
        **kwargs: Forwarded unchanged to ``sample`` (e.g. ``temperature``, ``top_k``).

    Returns:
        The sampled token, converted to the same type as ``x``.
    """
    sampled = sample(model(x, input_pos), **kwargs)
    return sampled.type_as(x)
@torch.inference_mode()
def generate(
    model: GPT,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered.
            This also applies when <eos> is the very first generated token.

    Returns:
        The prompt concatenated with the generated tokens. When generation stops on
        ``eos_id``, the <eos> token itself is included as the last element.

    Raises:
        AssertionError: If ``max_returned_tokens`` is not larger than the prompt length.
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T, "max_returned_tokens must be larger than the prompt length"
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device = prompt.device
    tokens = [prompt]
    input_pos = torch.tensor([T], device=device)

    # First step: feed the whole prompt at positions 0..T-1 to obtain the first new token.
    token = next_token(
        model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k
    ).clone()
    tokens.append(token)
    # The first generated token can already be <eos>; honor the stop condition here as well.
    if eos_id is not None and token == eos_id:
        return torch.cat(tokens)

    # Remaining steps: feed only the previously generated token at position `input_pos`.
    for _ in range(2, max_returned_tokens - T + 1):
        token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()
        tokens.append(token)
        # Skip the tensor comparison entirely when no stop token was requested.
        if eos_id is not None and token == eos_id:
            break
        input_pos = input_pos.add_(1)
    return torch.cat(tokens)
def main(
    prompt: str = "What food do llamas eat?",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 50,
    top_k: Optional[int] = 200,
    temperature: float = 0.8,
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
    strategy: str = "auto",
    devices: int = 1,
    precision: Optional[str] = None,
    compile: bool = False,
    model_name: str = "pythia_160m_hf"
) -> list:
    """Generates text samples based on a pre-trained model and tokenizer.

    Args:
        prompt: The prompt string to use for generating the samples.
        num_samples: The number of text samples to generate.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_dir: The checkpoint directory to load.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            - gptq.int4: 4-bit quantization from GPTQ
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        strategy: Indicates the Fabric strategy setting to use.
        devices: How many devices to use.
        precision: Indicates the Fabric precision setting to use.
        compile: Whether to compile the model.
        model_name: Passed to ``check_valid_checkpoint_dir`` and, unless ``quantize == "gptq.int4"``, used to pick
            the weights file inside ``checkpoint_dir``: ``"pythia_160m_deduped_huggingface"`` selects
            ``pythia_160m_deduped_hf.pth``, ``"pythia_160m_deduped_custom"`` selects
            ``pythia_160m_deduped_custom.pth``, and any other value (including the default) selects
            ``lit_model.pth``.

    Returns:
        A list with one decoded string per generated sample (``num_samples`` entries).

    Raises:
        NotImplementedError: If quantization is requested with more than one device.
        ValueError: If bitsandbytes quantization is combined with a mixed precision setting, or if
            ``quantize == "gptq.int4"`` and the quantized checkpoint file does not exist.
    """
    precision = precision or get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None:
        if devices > 1:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the"
                " --quantize flag."
            )
        if quantize.startswith("bnb."):
            if "mixed" in precision:
                raise ValueError("Quantization and mixed precision is not supported.")
            dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
            # strip the "bnb." prefix; the precision is then handled by the plugin instead of Fabric
            plugins = BitsandbytesPrecision(quantize[4:], dtype)
            precision = None

    if strategy == "fsdp":
        strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)

    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy, plugins=plugins)
    fabric.launch()

    check_valid_checkpoint_dir(checkpoint_dir, model_name)

    config = Config.from_json(checkpoint_dir / "lit_config.json")

    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        # model-specific weight files; every other model name uses the default file
        model_file = {
            "pythia_160m_deduped_huggingface": "pythia_160m_deduped_hf.pth",
            "pythia_160m_deduped_custom": "pythia_160m_deduped_custom.pth",
        }.get(model_name, "lit_model.pth")
    checkpoint_path = checkpoint_dir / model_file

    tokenizer = Tokenizer(checkpoint_dir)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    if compile:
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.coordinate_descent_tuning = True
        # rebind the module-level function so that `generate` picks up the compiled version
        global next_token
        next_token = torch.compile(next_token, mode="reduce-overhead")

    model = fabric.setup_module(model)

    t0 = time.perf_counter()
    load_checkpoint(fabric, model, checkpoint_path)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

    L.seed_everything(1234)
    print(f'num_samples is {num_samples}')
    output_msg_list = []
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0
        # clear the kv cache of every block before the next sample
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        output_msg = tokenizer.decode(y)
        tokens_generated = y.size(0) - prompt_length
        output_msg_list.append(output_msg)
        fabric.print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
        )
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
    return output_msg_list
# Script entry point: expose `main`'s parameters as command-line options via jsonargparse.
if __name__ == "__main__":
    from jsonargparse import CLI

    torch.set_float32_matmul_precision("high")
    # `CLI` returns whatever `main` returns, i.e. the list of decoded samples.
    output_msg_list = CLI(main)