import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F


class PIPELINE_ARGS():
    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1,
                 alpha_frequency=0.2, alpha_presence=0.2, temperature_a=1.0,
                 token_ban=[], token_stop=[], chunk_len=256):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.typical_p = typical_p
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        self.temperature_a = temperature_a  # extra temperature applied to the raw logits before softmax
        self.token_ban = token_ban  # ban the generation of some tokens
        self.token_stop = token_stop  # stop generation whenever you see any token here
        self.chunk_len = chunk_len  # split input into chunks to save VRAM (shorter -> slower)


class PIPELINE():
    def __init__(self, model, WORD_NAME):
        self.model = model
        if WORD_NAME == 'cl100k_base':
            import tiktoken
            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == 'rwkv_vocab_v20230424':
            from rwkv_tokenizer import TRIE_TOKENIZER
            self.tokenizer = TRIE_TOKENIZER(f'./{WORD_NAME}.txt')
        else:
            from tokenizers import Tokenizer
            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        # Normalize the prompt: strip whitespace (including ideographic spaces
        # and carriage returns) from every line, drop empty lines, and make
        # sure the result starts with a newline.
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def encode(self, x):
        encoded = self.tokenizer.encode(x)
        if hasattr(encoded, 'ids'):  # HuggingFace tokenizers return an Encoding object, not a list
            encoded = encoded.ids
        return encoded

    def decode(self, x):
        return self.tokenizer.decode(x)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, temperature_a=1.0):
        if temperature_a != 1.0:
            logits = logits / temperature_a  # pre-softmax temperature
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if typical_p < 1:
            # Typical-sampling-style filter: rank tokens by |logit - entropy|
            # and keep only the most "typical" tokens, up to typical_p
            # cumulative probability mass.
            entropy = torch.nansum(-torch.log(probs) * probs, dim=-1, keepdim=True)
            typical_scores = torch.abs(logits - entropy)
            typical_sorted_ids = torch.argsort(typical_scores)
            sorted_typical_scores = typical_scores[typical_sorted_ids]
            typical_sorted_probs = probs[typical_sorted_ids]
            cum_typical_sorted_probs = torch.cumsum(typical_sorted_probs, dim=-1).cpu().numpy()
            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs >= typical_p)])
        if probs.device == torch.device('cpu'):
            # CPU path: sample with NumPy.
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]  # descending
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0  # zero out everything outside the top-k
            if typical_p < 1:
                # probs is a NumPy array here, so convert the torch mask first
                probs[(typical_scores > typical_cutoff).numpy()] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            # GPU path: sample with torch.multinomial (accepts unnormalized weights).
            sorted_ids = torch.argsort(probs)
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))  # descending
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                probs[typical_scores > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)

    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]
            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p,
                                       top_k=args.top_k, typical_p=args.typical_p,
                                       temperature_a=args.temperature_a)
            if token in args.token_stop:
                break
            all_tokens += [token]
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # output
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # '\ufffd' means an incomplete UTF-8 sequence; wait for more tokens
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
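

if __name__ == '__main__':
    # Usage sketch, not part of the pipeline itself. It assumes the `rwkv`
    # pip package (rwkv.model.RWKV) is installed and that the vocab file
    # sits next to this script; the checkpoint path below is a placeholder.
    from rwkv.model import RWKV

    model = RWKV(model='/path/to/model.pth', strategy='cpu fp32')  # placeholder checkpoint path
    pipeline = PIPELINE(model, 'rwkv_vocab_v20230424')
    args = PIPELINE_ARGS(temperature=1.0, top_p=0.7, token_stop=[0])
    ctx = pipeline.refine_context('In a shocking finding, scientists discovered a herd of dragons.')
    # Stream tokens to stdout as they become valid UTF-8.
    pipeline.generate(ctx, token_count=100, args=args,
                      callback=lambda s: print(s, end='', flush=True))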