Spaces:
Runtime error
Runtime error
import os | |
import json | |
import requests | |
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer | |
from urllib.parse import parse_qs, urlparse | |
from inference import infer_t5 | |
from dataset import query_emotion | |
# Hugging Face access token sent as the Bearer credential to the Inference API.
# It is read from the BIG_GAN_TOKEN environment variable and is None when unset.
# Tokens are created at https://huggingface.co/settings/tokens; presumably the
# value is stored as a secret under
# https://huggingface.co/spaces/{username}/{space}/settings.
API_TOKEN = os.environ.get("BIG_GAN_TOKEN")
class RequestHandler(SimpleHTTPRequestHandler):
    """Serve the static front end plus three GET API endpoints.

    Routes are matched by prefix against the raw request path (query string
    included):

    * ``/``                 -> ``index.html``, served as a static file.
    * ``/infer_biggan...``  -> forwards ``?input=`` to the Hugging Face
      Inference API model ``osanseviero/BigGAN-deep-128`` and relays the
      upstream status code, content type and body.
    * ``/infer_t5...``      -> ``{"output": infer_t5(input)}`` as JSON.
    * ``/query_emotion...`` -> ``{"output": query_emotion(start, end)}`` as
      JSON, where ``start`` and ``end`` are integer query parameters.
    * anything else         -> static file handling from the base class.

    A missing or malformed query parameter is answered with ``400 Bad Request``.
    Previously the handler raised an exception and the client got no response.
    """

    # Hugging Face Inference API endpoint used by /infer_biggan.
    BIGGAN_URL = (
        "https://api-inference.huggingface.co/models/osanseviero/BigGAN-deep-128"
    )
    # Seconds to wait on the inference API. Without a timeout, a stalled
    # upstream would block this request's thread indefinitely.
    BIGGAN_TIMEOUT = 60

    def do_GET(self):
        """Dispatch a GET request to the matching route (see class docstring)."""
        if self.path == "/":
            self.path = "index.html"
            return super().do_GET()
        if self.path.startswith("/infer_biggan"):
            self._handle_infer_biggan()
        elif self.path.startswith("/infer_t5"):
            self._handle_infer_t5()
        elif self.path.startswith("/query_emotion"):
            self._handle_query_emotion()
        else:
            return super().do_GET()

    def _query_param(self, name):
        """Return the first value of query parameter *name*, or None if absent.

        ``parse_qs`` drops blank values by default, so ``?name=`` also yields
        None.
        """
        values = parse_qs(urlparse(self.path).query).get(name)
        return values[0] if values else None

    def _send_bytes(self, status, body, content_type="application/json"):
        """Send a complete response: status line, headers, then *body* bytes."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, payload):
        """Send *payload* serialised as a UTF-8 JSON ``200 OK`` response."""
        self._send_bytes(200, json.dumps(payload).encode("utf-8"))

    def _handle_infer_biggan(self):
        """Proxy ``?input=`` to the BigGAN model and relay its response."""
        prompt = self._query_param("input")
        if prompt is None:
            self.send_error(400, "Missing required query parameter: input")
            return
        try:
            upstream = requests.request(
                "POST",
                self.BIGGAN_URL,
                headers={"Authorization": f"Bearer {API_TOKEN}"},
                data=json.dumps(prompt),
                timeout=self.BIGGAN_TIMEOUT,
            )
        except requests.RequestException as err:
            # Connection failure or timeout while talking to the inference API.
            self.log_error("BigGAN request failed: %s", err)
            self.send_error(502, "Inference API request failed")
            return
        # Relay the upstream status and content type rather than always
        # claiming 200 / JSON, so the client can detect upstream errors.
        self._send_bytes(
            upstream.status_code,
            upstream.content,
            upstream.headers.get("Content-Type", "application/json"),
        )

    def _handle_infer_t5(self):
        """Answer ``{"output": infer_t5(input)}`` for the ``?input=`` text."""
        text = self._query_param("input")
        if text is None:
            self.send_error(400, "Missing required query parameter: input")
            return
        self._send_json({"output": infer_t5(text)})

    def _handle_query_emotion(self):
        """Answer ``{"output": query_emotion(start, end)}`` for integer bounds."""
        raw_start = self._query_param("start")
        raw_end = self._query_param("end")
        try:
            # int(None) raises TypeError (parameter missing); non-numeric
            # text raises ValueError.
            start, end = int(raw_start), int(raw_end)
        except (TypeError, ValueError):
            self.send_error(400, "Query parameters start and end must be integers")
            return
        self._send_json({"output": query_emotion(start, end)})
# Listen on every interface at port 7860. ThreadingHTTPServer handles each
# request in its own thread, and serve_forever() blocks until shutdown().
server = ThreadingHTTPServer(
    server_address=("", 7860),
    RequestHandlerClass=RequestHandler,
)
server.serve_forever()