# ChatGal / utils.py
import numpy as np
import torch
from torch.nn import functional as F

class PIPELINE_ARGS():
    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1,
                 alpha_frequency=0.2, alpha_presence=0.2, temperature_a=1.0,
                 token_ban=[], token_stop=[], chunk_len=256):
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.typical_p = typical_p
        self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
        self.temperature_a = temperature_a # extra temperature applied to the logits before softmax
        self.token_ban = token_ban # ban the generation of some tokens
        self.token_stop = token_stop # stop generation whenever you see any token here
        self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
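
# A minimal usage sketch for the sampler settings (the values below are
# illustrative, not recommendations):
#   args = PIPELINE_ARGS(temperature=1.2, top_p=0.5, alpha_presence=0.4,
#                        alpha_frequency=0.4, token_stop=[0])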

class PIPELINE():
    def __init__(self, model, WORD_NAME):
        self.model = model
        if WORD_NAME == 'cl100k_base':
            # OpenAI BPE vocabulary via tiktoken
            import tiktoken
            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == 'rwkv_vocab_v20230424':
            # RWKV "world" vocabulary via the trie-based tokenizer
            from rwkv_tokenizer import TRIE_TOKENIZER
            self.tokenizer = TRIE_TOKENIZER(f'./{WORD_NAME}.txt')
        else:
            # otherwise treat WORD_NAME as a path to a HuggingFace tokenizers file
            from tokenizers import Tokenizer
            self.tokenizer = Tokenizer.from_file(WORD_NAME)
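
    # The three tokenizer modes, for reference (the .json path is an
    # assumption; e.g. ChatRWKV ships 20B_tokenizer.json):
    #   PIPELINE(model, 'cl100k_base')            # tiktoken encoding
    #   PIPELINE(model, 'rwkv_vocab_v20230424')   # RWKV world trie tokenizer
    #   PIPELINE(model, '20B_tokenizer.json')     # HuggingFace tokenizers file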

    def refine_context(self, context):
        # normalize the prompt: strip each line (including ideographic spaces
        # and stray '\r'), drop empty lines, and prepend a single newline
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n' # safety net; context always starts with '\n' above
        return context
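
    # For example (illustrative):
    #   refine_context('  foo \n\n\u3000bar ')  ->  '\nfoo\nbar'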

    def encode(self, x):
        # works for both tokenizer flavors: tiktoken and the trie tokenizer
        # return a plain list, HuggingFace tokenizers returns an Encoding
        # object whose ids live in the .ids attribute
        encoded = self.tokenizer.encode(x)
        if hasattr(encoded, "ids"):
            encoded = encoded.ids
        return encoded

    def decode(self, x):
        return self.tokenizer.decode(x)
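
    # Round-trip sketch (illustrative):
    #   ids = pipeline.encode('Hello')   # a list of ints
    #   pipeline.decode(ids)             # -> 'Hello'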

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, temperature_a=1.0):
        # temperature_a rescales the logits before softmax; temperature
        # reshapes the truncated distribution after top-p/top-k/typical filtering
        if temperature_a != 1.0:
            logits = logits / temperature_a
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if typical_p < 1:
            # typical sampling: score tokens by how far they sit from the
            # distribution's entropy, then keep the most "typical" mass
            entropy = torch.nansum(-torch.log(probs) * probs, dim=-1, keepdim=True)
            typical_scores = torch.abs(logits - entropy)
            typical_sorted_ids = torch.argsort(typical_scores)
            sorted_typical_scores = typical_scores[typical_sorted_ids]
            typical_sorted_probs = probs[typical_sorted_ids]
            cum_typical_sorted_probs = torch.cumsum(typical_sorted_probs, dim=-1).cpu().numpy()
            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs >= typical_p)])
        if probs.device == torch.device('cpu'):
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                # typical_scores is a torch tensor here; convert the boolean
                # mask to numpy before indexing the numpy probs
                probs[(typical_scores > typical_cutoff).numpy()] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)
            sorted_probs = probs[sorted_ids]
            sorted_probs = torch.flip(sorted_probs, dims=(0,))
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
            if top_p < 1:
                probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                probs[typical_scores > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)
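
    # Standalone sketch (illustrative; assumes a single forward pass):
    #   out, state = model.forward(pipeline.encode('Hello'), None)
    #   next_token = pipeline.sample_logits(out, temperature=1.0, top_p=0.2)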

    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                # feed the prompt in chunk_len pieces to limit peak VRAM
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]
            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                # presence/frequency repetition penalties (as in GPT-3)
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p, temperature_a=args.temperature_a)
            if token in args.token_stop:
                break
            all_tokens += [token]
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # output
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only emit once the bytes decode to valid UTF-8
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
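
# End-to-end usage sketch. The RWKV model class, weight path, and strategy
# string are assumptions (from the rwkv pip package), not part of this file:
#
#   from rwkv.model import RWKV
#   model = RWKV(model='path/to/RWKV-weights.pth', strategy='cpu fp32')
#   pipeline = PIPELINE(model, 'rwkv_vocab_v20230424')
#   args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, token_stop=[0])
#   print(pipeline.generate('\nHello', token_count=64, args=args))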