import os
import random
import re
from typing import Any, Dict

import torch
from transformers import AutoTokenizer, CodeGenForCausalLM

# Use the first GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Generation hyperparameters.
seed = 16
max_length = 2048          # hard cap on prompt length (tokens)
max_gen_length = 128       # tokens to generate beyond the prompt
top_p = 0.95
temp = 0.2
num_return_sequences = 1
pad_token_id = 50256       # CodeGen/GPT-2 EOS token id, reused for padding
prefix = "# Import libraries.\n\nimport numpy as np\n\n"

# Seed everything for reproducible generation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick kernels nondeterministically, which
    # defeats the deterministic flag above, so keep it off.
    torch.backends.cudnn.benchmark = False


def truncate(completion):
    """Cut a completion at the first natural stopping point: a second
    top-level print or def, a top-level comment or docstring, an EOS
    marker, or a triple blank line."""

    def find_re(string, pattern, start_pos):
        m = pattern.search(string, start_pos)
        return m.start() if m else -1

    terminals = [
        re.compile(r, re.MULTILINE)
        for r in ["^#", re.escape("<|endoftext|>"), "^'''", '^"""', "\n\n\n"]
    ]

    # Keep only the first top-level print statement.
    prints = list(re.finditer("^print", completion, re.MULTILINE))
    if len(prints) > 1:
        completion = completion[: prints[1].start()]

    # Keep only the first top-level function definition.
    defs = list(re.finditer("^def", completion, re.MULTILINE))
    if len(defs) > 1:
        completion = completion[: defs[1].start()]

    start_pos = 0
    terminals_pos = [
        pos
        for pos in [find_re(completion, terminal, start_pos) for terminal in terminals]
        if pos != -1
    ]
    if terminals_pos:
        return completion[: min(terminals_pos)]
    return completion


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and model; use fp16 on GPU to halve memory.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        if torch.cuda.is_available():
            self.model = CodeGenForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
        else:
            self.model = CodeGenForCausalLM.from_pretrained(path)
        # Left-pad with the EOS token so generation starts flush with the prompt.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"  # was `padding_size`, a silent no-op
        self.model.to(device)

    def __call__(self, data: Dict[str, Any]) -> str:
        prompt = data.pop("inputs", data)
        function = data.pop("parameters", False)
        print("prompt is: ", prompt)
        print("function is: ", function)
        prompt = prefix + prompt

        # Steer generation: force `def` when the caller asked for a function,
        # otherwise forbid it; likewise force or forbid `print` depending on
        # whether the last comment in the prompt mentions printing.
        bad_words = []
        force_words = []
        if function:
            force_words.append(self.tokenizer(["def"], add_special_tokens=False).input_ids)
        else:
            bad_words.append(self.tokenizer(["def"], add_special_tokens=False).input_ids)
        # NOTE: the original tested `prompt[-1]`, a single character, so the
        # condition could never match; check the text after the last `#` instead.
        if "print" in prompt.split("#")[-1].lower():
            force_words.append(self.tokenizer(["print"], add_special_tokens=False).input_ids)
        else:
            bad_words.append(self.tokenizer(["print"], add_special_tokens=False).input_ids)
        force_words = [item for sublist in force_words for item in sublist]
        bad_words = [item for sublist in bad_words for item in sublist]

        input_ids = self.tokenizer(
            prompt,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids
        input_ids_len = input_ids.shape[1]
        assert input_ids_len < max_length

        with torch.no_grad():
            input_ids = input_ids.to(device)
            gen_kwargs = dict(
                num_return_sequences=num_return_sequences,
                max_length=input_ids_len + max_gen_length,
                pad_token_id=pad_token_id,
                use_cache=True,
            )
            if force_words:
                # force_words_ids requires beam search, so sampling is off;
                # temperature/top_p are sampling-only knobs and are dropped here.
                gen_kwargs.update(do_sample=False, num_beams=5, force_words_ids=force_words)
            else:
                gen_kwargs.update(do_sample=True, temperature=temp, top_p=top_p)
            if bad_words:
                gen_kwargs["bad_words_ids"] = bad_words
            tokens = self.model.generate(input_ids, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.batch_decode(tokens[:, input_ids_len:])
        desired = (
            "THE PROMPT IS:"
            + "\n======================="
            + prompt
            + "\n FUNCTION IS"
            + str(function)  # may be a bool or dict, so stringify before concatenating
            + "\n======================="
        )
        return truncate(text[0]) + desired
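
# ---------------------------------------------------------------------------
# Minimal local smoke test — a sketch, not part of the endpoint contract.
# Hugging Face Inference Endpoints instantiates EndpointHandler itself and
# passes the request payload as `data`; this block only exists so the handler
# can be exercised from the command line. The checkpoint name below is an
# assumption — substitute whichever CodeGen checkpoint the endpoint serves.
if __name__ == "__main__":
    # Hypothetical checkpoint; any CodeGen causal-LM checkpoint should work.
    handler = EndpointHandler(path="Salesforce/codegen-350M-mono")
    payload = {
        "inputs": "# Print the squares of the numbers 1 through 5.\n",
        "parameters": False,  # True would force a `def` into the completion
    }
    print(handler(payload))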