Quiet-Star-Custom / inference.py
Crystalcareai's picture
Update inference.py
d3c4ad0 verified
raw
history blame contribute delete
No virus
657 Bytes
import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
# Load the Quiet-STaR base model and stream a short completion for a fixed prompt.
model_path = "cognitivecomputations/Quiet-STaR-Base"

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",           # let accelerate place layers on available device(s)
    low_cpu_mem_usage=True,      # stream weights in, avoiding a full extra copy in RAM
    torch_dtype=torch.bfloat16,
    # NOTE(review): trust_remote_code executes Python from the model repo —
    # acceptable here only because the repo is explicitly trusted.
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Stream decoded tokens to stdout as they are produced, omitting the echoed
# prompt and any special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

prompt = "Hello my name is"
# Use model.device instead of .cuda(): the original hard-coded CUDA, which
# crashes on CPU-only hosts and can mismatch the placement chosen by
# device_map="auto" (e.g. when the first layers land on a non-default device).
tokens = tokenizer(
    prompt,
    return_tensors='pt'
).input_ids.to(model.device)

generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512,
)