import logging

import torch
# from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ChallengePromptGenerator:
    def __init__(
        self,
        model_local_dir="checkpoint-15000",
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = AutoModelForCausalLM.from_pretrained(
            model_local_dir, device_map=self.device
        )
        # to_bettertransformer() returns the converted model, so keep the reference.
        self.generator = self.generator.to_bettertransformer()
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_dir)
        # Batched generation with padding requires a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=100,
        sampling_topp=1,
    ):
        # Prepend the BOS token manually, since add_special_tokens is disabled below.
        prompts = [f"{self.tokenizer.bos_token} {prompt}" for prompt in prompts]

        # Tokenize the batch, padding and truncating to a fixed window.
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=False,
        ).to(self.device)

        # Generate with sampling (temperature / top-k / top-p) and optional beams.
        # Note: max_length counts prompt plus generated tokens.
        outputs = self.generator.generate(
            **inputs,
            max_length=max_generation_length,
            num_beams=beam_size,
            temperature=sampling_temperature,
            top_k=sampling_topk,
            top_p=sampling_topp,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # Decode, then trim each completion back to its last full sentence.
        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        trimmed_outputs = []
        for out in decoded_outputs:
            if out and out[-1] != ".":
                out = ".".join(out.split(".")[:-1]) + "."
            trimmed_outputs.append(out)
        return trimmed_outputs
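

# Minimal usage sketch: this assumes a local "checkpoint-15000" directory containing
# both model weights and tokenizer files, and the example prompts are illustrative only.
if __name__ == "__main__":
    generator = ChallengePromptGenerator(model_local_dir="checkpoint-15000")
    example_prompts = [
        "A photo of",
        "An illustration of",
    ]
    completions = generator.infer_prompt(example_prompts, max_generation_length=77)
    for prompt, completion in zip(example_prompts, completions):
        logger.info("Prompt: %r -> Completion: %r", prompt, completion)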