File size: 8,197 Bytes
c5a6a24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from tsai_gpt.model import GPT, Block, Config
from tsai_gpt.tokenizer import Tokenizer
from tsai_gpt.utils import (get_default_supported_precision, gptq_quantization,
load_checkpoint)
# Seed the global random number generators once, at import time, so that sampling runs
# started from this module are reproducible.
L.seed_everything(seed=1234)
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single category index from ``probs`` along its last dimension.

    In eager mode this delegates to ``torch.multinomial(probs, num_samples=1)``. While
    ``torch.compile`` is tracing, an exponential-race trick is used instead, because it is
    faster and CUDAGraph friendly. Dividing the weights by i.i.d. Exp(1) noise and taking the
    argmax selects index ``i`` with probability proportional to ``probs[i]``.

    Args:
        probs: Non-negative sampling weights; the last dimension indexes categories.

    Returns:
        The sampled indices, with the last dimension kept at size 1.
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)
    exp_noise = torch.empty_like(probs).exponential_(1)
    return (probs / exp_noise).argmax(dim=-1, keepdim=True)
def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    """Choose one token id from the scores of the final position.

    Only ``logits[0, -1]`` (first batch entry, last position) is used. The layout is presumably
    (batch, seq, vocab); confirm against the model output.

    Args:
        logits: Model output scores.
        temperature: Divides the scores before the softmax. Any value that is not strictly
            positive selects the highest-scoring token deterministically.
        top_k: If given, only the ``top_k`` highest-scoring tokens remain candidates.

    Returns:
        A 1-element tensor holding the chosen token index.
    """
    scores = logits[0, -1]
    if top_k is not None:
        # Keep exactly the k best scores and set everything else to -inf. nanoGPT's
        # threshold-based masking (torch.where on the k-th value) can keep more than k entries
        # when scores tie, so a scatter into a -inf tensor is used instead.
        kept_values, kept_idx = torch.topk(scores, min(top_k, scores.size(-1)))
        scores = torch.full_like(scores, float("-inf")).scatter_(-1, kept_idx, kept_values)
    if not temperature > 0.0:
        # Greedy decoding when the temperature is not positive.
        return torch.argmax(scores, dim=-1, keepdim=True)
    probs = torch.nn.functional.softmax(scores / temperature, dim=-1)
    return multinomial_num_samples_1(probs)
def next_token(
    model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
    """Run one forward pass of ``model`` and sample the token that follows ``x``.

    Args:
        model: The model to query; it is called as ``model(x, input_pos)``.
        input_pos: Positions of the tokens in ``x``.
        x: Token indices fed to the model (``generate`` passes a ``(1, -1)`` view).
        **kwargs: Forwarded unchanged to ``sample`` (e.g. ``temperature``, ``top_k``).

    Returns:
        The sampled token index, converted with ``type_as(x)`` to the tensor type of ``x``.
    """
    logits = model(x, input_pos)
    # Named `sampled` rather than `next` so the builtin `next` is not shadowed.
    sampled = sample(logits, **kwargs)
    return sampled.type_as(x)
@torch.inference_mode()
def generate(
    model: GPT,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Take a conditioning sequence (prompt) and continue it with as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature; a non-positive value
            selects the most likely token (greedy decoding).
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating as soon as this token is produced. This includes
            the very first generated token. The EOS token itself is kept in the output.

    Returns:
        1-D tensor: the prompt followed by the generated tokens.

    Raises:
        AssertionError: If ``max_returned_tokens`` does not exceed the prompt length.
        NotImplementedError: If ``model.max_seq_length`` is smaller than ``max_returned_tokens - 1``.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    device = prompt.device
    tokens = [prompt]
    input_pos = torch.tensor([T], device=device)

    # Prefill: feed the whole prompt (positions 0..T-1) and sample the first new token.
    token = next_token(
        model,
        torch.arange(0, T, device=device),
        prompt.view(1, -1),
        temperature=temperature,
        top_k=top_k,
    ).clone()
    tokens.append(token)
    # The first sampled token must honour eos_id too. Previously only later tokens were
    # checked, so an EOS produced right after the prompt did not stop generation.
    if eos_id is not None and token == eos_id:
        return torch.cat(tokens)

    # Decode: feed back one token at a time at the next position.
    for _ in range(2, max_returned_tokens - T + 1):
        token = next_token(
            model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
        ).clone()
        tokens.append(token)
        if eos_id is not None and token == eos_id:
            break
        input_pos = input_pos.add_(1)
    return torch.cat(tokens)
"""
quantize (Optional[Literal["bnb.nf4", "bnb.nf4, optional): quantization method to use. Defaults to None.
- "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq": 4-bit quantization bitsandbytes
- "bnb.int8": 8-bit quantization bitsandbytes
- "gptq.int4": 4-bit quantization GPTQ
for more details see: https://github.com/facebookresearch/bitsandbytes, https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
strategy (str, optional): Fabric strategy setting. Defaults to "auto".
devices (int, optional): number of devices to be used. Defaults to 1.
precision (Optional[str], optional): fabic precision settings. Defaults to None.
"""
chptk_path: str = "saved_model/last-iter-015000-ckpt.pth"
tokenizer_path: str = "tokenizer_Llama-2-7b-chat-hf"
quantize: Optional[
Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]
] = None
strategy: str = "auto"
devices: int = 1
precision: Optional[str] = None
precision = precision or get_default_supported_precision(training=False)
plugins = None
if quantize is not None:
if devices > 1:
raise NotImplemented("Multi-GPU quantization is not supported yet.")
if quantize.startswith("bnb."):
if "mixed" in precision:
raise ValueError("Quantization and mixed precision is not supported.")
dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
precision
]
plugins = BitsandbytesPrecision(quantize[4:], dtype)
precision = None
if strategy == "fsdp":
strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, plugins=plugins)
fabric.launch()
tokenizer = Tokenizer(Path("tokenizer_Llama-2-7b-chat-hf"))
config = Config.from_name("pythia-160m")
fabric.print(f"Loading model from {chptk_path}", file=sys.stderr)
t0 = time.perf_counter()
with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
model = GPT(config)
fabric.print(
f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
)
with fabric.init_tensor():
# enable the kv cache
model.set_kv_cache(batch_size=1)
model.eval()
model = fabric.setup_module(model)
t0 = time.perf_counter()
load_checkpoint(fabric, model, chptk_path)
fabric.print(
f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
)
def generate_from_prompt(
    prompt: str = "",
    max_new_tokens: int = 500,
    top_k: int = 200,
    temperature: float = 0.8,
    num_samples: int = 1,
):
    """Generate text from a prompt using the pre-trained model loaded at module import.

    Each sample's text is printed. Per-sample timing is printed to stderr. On CUDA, the peak
    allocated memory is also printed to stderr after all samples.

    Args:
        prompt (str, optional): Prompt string to be used for generating samples. Defaults to "".
        max_new_tokens (int, optional): number of generation steps to take. Defaults to 500.
        top_k (int, optional): top most preferable tokens to consider in the sampling process. Defaults to 200.
        temperature (float, optional): Controls randomness of the sampling process. Defaults to 0.8.
        num_samples (int, optional): Number of samples to be generated. Defaults to 1.

    Returns:
        The decoded text (prompt followed by the generated tokens) of the last sample.

    Raises:
        ValueError: If ``num_samples`` is less than 1 (there would be nothing to return).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens

    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0
        pred = tokenizer.decode(y)
        fabric.print(pred)
        tokens_generated = y.size(0) - prompt_length
        fabric.print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
            file=sys.stderr,
        )
    if fabric.device.type == "cuda":
        fabric.print(
            f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr
        )
    return pred
|