# T2IPromptGenerator / model.py
import logging
import os

import torch
# from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ChallengePromptGenerator:
    def __init__(
        self,
        model_local_dir="checkpoint-15000",
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = AutoModelForCausalLM.from_pretrained(
            model_local_dir, device_map=self.device
        )
        # BetterTransformer (via optimum) swaps in fused attention kernels for faster inference.
        self.generator.to_bettertransformer()
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_dir)
        # Batched generation needs a pad token; fall back to EOS if the checkpoint defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=100,
        sampling_topp=1.0,
    ):
        """Continue each seed prompt with the fine-tuned causal LM and return full-sentence outputs."""
        # Prepend the BOS token manually, since the tokenizer is called with add_special_tokens=False.
        prompts = [f"{self.tokenizer.bos_token} {prompt}" for prompt in prompts]
# Prepare inputs
inputs = self.tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256,
add_special_tokens=False
).to(self.device)
        # Generate continuations. Note that max_length counts the prompt tokens as well,
        # so the total sequence (prompt + continuation) is capped at max_generation_length.
        outputs = self.generator.generate(
            **inputs,
            max_length=max_generation_length,
            num_beams=beam_size,
            temperature=sampling_temperature,
            top_k=sampling_topk,
            top_p=sampling_topp,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Decode and drop any trailing incomplete sentence so each prompt ends at a full stop.
        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        trimmed = []
        for out in decoded_outputs:
            out = out.strip()
            if out and "." in out and not out.endswith("."):
                out = ".".join(out.split(".")[:-1]) + "."
            trimmed.append(out)
        return trimmed
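

# Minimal usage sketch: the seed prompts below are placeholder assumptions for illustration;
# the checkpoint directory matches the class default above.
if __name__ == "__main__":
    generator = ChallengePromptGenerator(model_local_dir="checkpoint-15000")
    expanded = generator.infer_prompt(
        ["a cat sitting on a windowsill", "a futuristic city at night"],
        max_generation_length=77,
        sampling_temperature=0.9,
    )
    for prompt in expanded:
        logger.info("Generated prompt: %s", prompt)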