import argparse
import json
import os
import random

import numpy as np
import torch

from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import TextStreamer, GenerationConfig


class LocalStoppingCriteria(StoppingCriteria):
    """Stops generation once the decoded output ends with one of the stop words."""

    def __init__(self, tokenizer, stop_words=None):
        super().__init__()
        stop_words = stop_words or []  # avoid a mutable default argument
        stops = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze()
                 for stop_word in stop_words]
        print('stop_words', stop_words)
        print('stop_words_ids', stops)
        self.stop_words = stop_words
        self.stops = [stop.cuda() for stop in stops]
        self.tokenizer = tokenizer

    def _compare_token(self, input_ids):
        # Token-level check: does the sequence end with a stop word's token ids?
        for stop in self.stops:
            if stop.dim() != 1:
                continue
            stop_len = len(stop)
            if torch.all(stop == input_ids[0][-stop_len:]).item():
                return True
        return False

    def _compare_decode(self, input_ids):
        # String-level check: decode the sequence and test its suffix. This is
        # more robust than _compare_token when a stop word tokenizes differently
        # depending on context.
        input_str = self.tokenizer.decode(input_ids[0])
        for stop_word in self.stop_words:
            if input_str.endswith(stop_word):
                return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return self._compare_decode(input_ids)


def seed_everything(seed: int):
    """Seed every RNG in use so runs are reproducible."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be False for determinism; True lets cuDNN autotune and
    # pick potentially non-deterministic kernels, defeating the seeding above.
    torch.backends.cudnn.benchmark = False


def generation(model, tokenizer, x, max_new_tokens=1024):
    # Stop as soon as the decoded output ends with the EOS token.
    stopping_criteria = StoppingCriteriaList(
        [LocalStoppingCriteria(tokenizer=tokenizer, stop_words=[tokenizer.eos_token])])
    streamer = TextStreamer(tokenizer)

    generation_config = GenerationConfig(
        temperature=1.0,
        top_p=0.8,
        top_k=100,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # early_stopping dropped: it only applies to beam search
    )
    inputs = tokenizer(x, return_tensors='pt', return_token_type_ids=False).to('cuda')
    gened = model.generate(
        **inputs,
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
        streamer=streamer,
    )
    # Slice off the prompt tokens rather than splitting on the prompt string,
    # which breaks whenever decoding does not reproduce `x` exactly.
    prompt_len = inputs['input_ids'].shape[1]
    response = tokenizer.decode(gened[0][prompt_len:])
    response = response.replace(tokenizer.eos_token, '')
    return response
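

# Minimal usage sketch (an assumption, not part of the original script): the
# checkpoint name below is a placeholder; any causal LM with an EOS token
# should work the same way with the helpers above.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    seed_everything(42)
    model_name = 'gpt2'  # hypothetical example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    print(generation(model, tokenizer, 'The quick brown fox', max_new_tokens=32))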