import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F
class PIPELINE_ARGS():
    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, alpha_frequency=0.2, alpha_presence=0.2, temperature_a=1.0, token_ban=[], token_stop=[], chunk_len=256):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.typical_p = typical_p # typical sampling threshold (1 = disabled)
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.temperature_a = temperature_a # secondary temperature applied to the raw logits before softmax
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
class PIPELINE():
    def __init__(self, model, WORD_NAME):
        self.model = model
        if WORD_NAME == 'cl100k_base':
            import tiktoken
            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == 'rwkv_vocab_v20230424':
            from rwkv_tokenizer import TRIE_TOKENIZER
            self.tokenizer = TRIE_TOKENIZER(f'./{WORD_NAME}.txt')
        else:
            from tokenizers import Tokenizer
            self.tokenizer = Tokenizer.from_file(WORD_NAME)
    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context
    def encode(self, x):
        # if 'tiktoken' in str(type(self.tokenizer)):
        #     return self.tokenizer.encode(x)
        # else:
        #     return self.tokenizer.encode(x).ids
        encoded = self.tokenizer.encode(x)
        if hasattr(encoded, "ids"):
            encoded = encoded.ids # HF tokenizers return an Encoding object; keep only the raw ids
        return encoded

    def decode(self, x):
        return self.tokenizer.decode(x)
    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, temperature_a=1.0):
        if temperature_a != 1.0:
            logits = logits / temperature_a
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if typical_p < 1:
            # typical sampling: rank tokens by |logit - entropy| and keep the most typical mass up to typical_p
            entropy = torch.nansum(-torch.log(probs) * probs, dim=-1, keepdim=True)
            typical_scores = torch.abs(logits - entropy)
            typical_sorted_ids = torch.argsort(typical_scores)
            sorted_typical_scores = typical_scores[typical_sorted_ids]
            typical_sorted_probs = probs[typical_sorted_ids]
            cum_typical_sorted_probs = torch.cumsum(typical_sorted_probs, dim=-1).cpu().numpy()
            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs >= typical_p)])
        if probs.device == torch.device('cpu'):
            # CPU tensors: filter and sample with numpy
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                probs[typical_scores > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            # GPU (or other device) tensors: filter in torch and sample with torch.multinomial
            sorted_ids = torch.argsort(probs)
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                probs[typical_scores > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)
    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]
            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p, temperature_a=args.temperature_a)
            if token in args.token_stop:
                break
            all_tokens += [token]
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # output
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # is valid utf-8 string?
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
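
# Example usage -- a minimal sketch. It assumes an RWKV model object whose
# .forward(tokens, state) returns (logits, state) (e.g. rwkv.model.RWKV from
# the rwkv pip package) and the vocab file './rwkv_vocab_v20230424.txt' on
# disk; the model path below is a placeholder, not a real checkpoint.
#
#   from rwkv.model import RWKV
#   model = RWKV(model='path/to/RWKV-world-model.pth', strategy='cpu fp32')
#   pipeline = PIPELINE(model, 'rwkv_vocab_v20230424')
#   args = PIPELINE_ARGS(temperature=1.0, top_p=0.7,
#                        alpha_presence=0.2, alpha_frequency=0.2,
#                        token_stop=[0])
#   print(pipeline.generate('\nQ: What is RWKV?\n\nA:', token_count=100, args=args))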