File size: 4,786 Bytes
f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 3f1b7f0 f289b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from ast import Dict
import logging
import logging.handlers
import os
import sys
import base64
from PIL import Image
from io import BytesIO
import json
import requests
from constants import LOGDIR
import datetime
# Canned user-facing error strings (their call sites are outside this file).
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

# Shared log-file handler; stays None until build_logger() creates it.
handler = None
def build_logger(logger_name, logger_filename):
    """Return the logger *logger_name* after wiring up process-wide logging.

    Each call:
      * configures the root logger at INFO if it has no handlers yet, and
        applies the shared formatter to the root logger's first handler;
      * routes ``sys.stdout`` / ``sys.stderr`` through ``StreamToLogger`` into
        the "stdout" (INFO) and "stderr" (ERROR) loggers.  The streams are
        wrapped only once; later calls do not wrap an existing wrapper again;
      * attaches one shared daily-rotating file handler to every existing
        logger that does not already have it.  The handler is created on the
        first call under ``LOGDIR``.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        logger_filename: File name inside ``LOGDIR`` for the shared handler.
            Only the first call's value is used, because the handler is
            created once and reused afterwards.

    Returns:
        The ``logging.Logger`` named *logger_name*, set to INFO.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers.  The loggers are configured on
    # every call, but a stream that is already a StreamToLogger is not wrapped
    # again.  Re-wrapping would chain wrappers on each call and would drop any
    # partial line still buffered in the previous wrapper.
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    if not isinstance(sys.stdout, StreamToLogger):
        sys.stdout = StreamToLogger(stdout_logger, logging.INFO)

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    if not isinstance(sys.stderr, StreamToLogger):
        sys.stderr = StreamToLogger(stderr_logger, logging.ERROR)

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Create the shared file handler once.  Use explicit UTF-8 so non-ASCII
    # log text does not depend on the platform's locale encoding.
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)

    # Add the file handler to all loggers, including loggers created since an
    # earlier call.  The membership check prevents duplicate attachment.
    # loggerDict is snapshotted because it can also contain PlaceHolder objects.
    for item in list(logging.root.manager.loggerDict.values()):
        if isinstance(item, logging.Logger) and handler not in item.handlers:
            item.addHandler(handler)

    return logger
class StreamToLogger:
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Written text is buffered until a newline character arrives.  Each complete
    line is logged at ``log_level`` with trailing whitespace stripped.  Any
    attribute not defined here is looked up on ``terminal``, which is the
    stream that was ``sys.stdout`` when the instance was created.
    """

    def __init__(self, logger, log_level=logging.INFO):
        # NOTE: this is sys.stdout even when the instance is installed as
        # sys.stderr; kept for compatibility with existing callers.
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # __getattr__ only runs for attributes missing from the instance.
        # Guard "terminal" so an instance without it (e.g. built via __new__)
        # raises AttributeError instead of recursing through self.terminal.
        if attr == "terminal":
            raise AttributeError(attr)
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log each complete line and keep the trailing fragment buffered.

        Returns the number of characters accepted, like ``io.TextIOBase.write``.
        """
        # Split on "\n" only.  str.splitlines() also breaks on "\r", "\x0c",
        # "\u2028", etc.  Such fragments were pushed back into linebuf and got
        # stuck at the front of every later write, reordering output and
        # growing the buffer.  "\r\n" endings still work because rstrip()
        # removes the leftover "\r".
        *complete_lines, self.linebuf = (self.linebuf + buf).split("\n")
        for line in complete_lines:
            self.logger.log(self.log_level, line.rstrip())
        return len(buf)

    def flush(self):
        """Log any buffered partial line and clear the buffer."""
        if self.linebuf != "":
            self.logger.log(self.log_level, self.linebuf.rstrip())
            self.linebuf = ""
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op for the whole process.  Layers built
    afterwards skip their default parameter initialisation (presumably because
    weights are loaded from a checkpoint later; verify against callers).
    """
    import torch

    def _no_op_reset(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _no_op_reset
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Posts *text* to the OpenAI moderations endpoint and returns the
    ``flagged`` field of the first result.

    Args:
        text: The string to check.

    Returns:
        The API's ``flagged`` value.  Returns False when the request fails or
        the response is not in the expected shape, so the check fails open.

    Raises:
        KeyError: if the ``OPENAI_API_KEY`` environment variable is unset.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    # Let requests serialise the body.  The former hand-built JSON string was
    # invalid whenever text contained a quote, backslash or control character.
    # The resulting API error was swallowed below, so such input was never
    # flagged.  Newlines are now sent as-is instead of being deleted, which
    # also avoids merging words across lines.
    payload = {"input": text}

    try:
        ret = requests.post(url, headers=headers, json=payload, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Error-shaped or empty response body, or a body that is not JSON.
        # ValueError covers JSON decode errors on older requests versions.
        flagged = False
    return flagged
def pretty_print_semaphore(semaphore):
    """Render a semaphore's counter and lock state as a short string.

    Returns ``"None"`` for ``None``.  Otherwise it reads the private
    ``_value`` counter and calls ``locked()`` on the semaphore.
    """
    if semaphore is not None:
        current_value = semaphore._value
        is_locked = semaphore.locked()
        return f"Semaphore(value={current_value}, locked={is_locked})"
    return "None"
def load_image_from_base64(image):
    """Decode base64 image data and open it with PIL.

    Args:
        image: Base64 payload (anything ``base64.b64decode`` accepts).

    Returns:
        The object returned by ``PIL.Image.open`` on the decoded bytes.
    """
    decoded = base64.b64decode(image)
    stream = BytesIO(decoded)
    return Image.open(stream)
def get_log_filename():
    """Return today's conversation-log path inside ``LOGDIR``.

    The file name is ``<year>-<MM>-<DD>-conv.json`` for the current local date.
    """
    now = datetime.datetime.now()
    date_stamp = "%d-%02d-%02d" % (now.year, now.month, now.day)
    return os.path.join(LOGDIR, date_stamp + "-conv.json")
def data_wrapper(data):
    """Convert a supported payload into ``bytes``.

    Args:
        data: One of
            * ``bytes``: returned unchanged;
            * ``str``: encoded as UTF-8;
            * ``dict``: serialised with ``json.dumps`` and then encoded;
            * ``PIL.Image.Image``: saved as PNG.

    Returns:
        The payload as ``bytes``.

    Raises:
        ValueError: for any other type.
    """
    if isinstance(data, bytes):
        return data
    if isinstance(data, str):
        return data.encode()
    # Use the builtin dict.  The module-level ``Dict`` comes from ``ast`` and
    # is the AST node class for dict literals, so isinstance(data, Dict) was
    # never true for a real dictionary and dicts fell through to ValueError.
    if isinstance(data, dict):
        return json.dumps(data).encode()
    # Checked after the stdlib types; these types do not overlap, so the order
    # does not change which branch an input takes.
    if isinstance(data, Image.Image):
        buffered = BytesIO()
        data.save(buffered, format="PNG")
        return buffered.getvalue()
    raise ValueError(f"Unsupported data type: {type(data)}")
|