Spaces: Running on CPU Upgrade
from threading import Thread | |
import falcon | |
from falcon.http_status import HTTPStatus | |
import json | |
import requests | |
import time | |
from Model import generate_completion | |
import sys | |
class AutoComplete(object):
    """Falcon resource that returns model-generated completions for a prompt.

    Expects a JSON POST body. The only required field is ``context``; every
    other field is optional and falls back to a default below.
    """

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        """Handle POST /autocomplete[/{x}[/{y}]].

        ``x`` and ``y`` are path parameters from the parameterized routes;
        they are accepted but not used by this handler. On success the
        response body is ``{"sentences": [...], "time": <seconds>}``; a
        missing ``context`` field yields HTTP 422.
        """
        json_data = json.loads(req.bounded_stream.read())
        start = time.time()

        # "context" is the only required field — reject the request without it.
        try:
            context = json_data["context"].rstrip()
        except KeyError:
            resp.body = "The context field is required"
            resp.status = falcon.HTTP_422
            return

        # Optional generation parameters, each with its documented default.
        # dict.get replaces the original per-key try/except KeyError chains.
        n_samples = json_data.get('samples', 3)  # read from request; not forwarded below
        length = json_data.get('gen_length', 20)
        max_time = json_data.get('max_time', -1)
        model_name = json_data.get('model_size', "small")
        temperature = json_data.get('temperature', 0.7)
        max_tokens = json_data.get('max_tokens', 256)
        top_p = json_data.get('top_p', 0.95)
        top_k = json_data.get('top_k', 40)
        # CTRL-specific parameter
        repetition_penalty = json_data.get('repetition_penalty', 0.02)
        # PPLM-specific parameters
        stepsize = json_data.get('step_size', 0.02)
        gm_scale = json_data.get('gm_scale')
        kl_scale = json_data.get('kl_scale')
        num_iterations = json_data.get('num_iterations')
        use_sampling = json_data.get('use_sampling')
        bag_of_words_or_discrim = json_data.get('bow_or_discrim', "kitchen")

        print(json_data)
        sentences = generate_completion(
            context,
            length=length,
            max_time=max_time,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
            # CTRL
            repetition_penalty=repetition_penalty,
            # PPLM
            stepsize=stepsize,
            bag_of_words_or_discrim=bag_of_words_or_discrim,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            num_iterations=num_iterations,
            use_sampling=use_sampling
        )
        resp.body = json.dumps({"sentences": sentences, 'time': time.time() - start})
        # Status is set exactly once per exit path (422 above, 200 here);
        # the original also assigned HTTP_200 redundantly before parsing.
        resp.status = falcon.HTTP_200
        sys.stdout.flush()
class Request(Thread):
    """Worker thread that POSTs a JSON payload and exposes the response text."""

    def __init__(self, end_point, data):
        """Store the target URL and the JSON payload to send.

        :param end_point: URL to POST to.
        :param data: JSON-serializable payload.
        """
        Thread.__init__(self)
        self.end_point = end_point
        self.data = data
        self.ret = None  # set by run() to the requests.Response

    def run(self):
        print("Requesting with url", self.end_point)
        self.ret = requests.post(url=self.end_point, json=self.data)

    def join(self, timeout=None):
        """Wait for the request thread and return the response body text.

        Fixes the original override, which dropped Thread.join's ``timeout``
        parameter and unconditionally dereferenced ``self.ret.text`` — an
        AttributeError whenever the thread had not completed. Returns None
        when no response is available (timeout expired or request failed).
        """
        Thread.join(self, timeout)
        if self.ret is None:
            return None
        return self.ret.text
class HandleCORS(object):
    """Falcon middleware that opens the API to all origins (CORS)."""

    def process_request(self, req, resp):
        """Attach wildcard CORS headers and short-circuit OPTIONS preflights."""
        # Every browser-relevant CORS header gets the wildcard value.
        for header_name in ('Access-Control-Allow-Origin',
                            'Access-Control-Allow-Methods',
                            'Access-Control-Allow-Headers'):
            resp.set_header(header_name, '*')
        # Preflight requests are answered immediately with an empty 200.
        if req.method == 'OPTIONS':
            raise HTTPStatus(falcon.HTTP_200, body='\n')
# Wire one shared AutoComplete resource to the bare route and its one- and
# two-parameter variants ({x}/{y} are accepted but unused by the handler).
autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
for route in ('/autocomplete', '/autocomplete/{x}', '/autocomplete/{x}/{y}'):
    app.add_route(route, autocomplete)
# WSGI servers conventionally look for a module-level name "application".
application = app