Spaces:
Runtime error
Runtime error
from config import * | |
# Remote-service settings. Every variable is mandatory: a missing one raises
# KeyError while this module is being imported.
apiUrl, uploadToken, openId, apiKey, Regions, tokenUrl = (
    os.environ[_name]
    for _name in ('apiUrl', 'uploadToken', 'openId', 'apiKey', 'Regions', 'tokenUrl')
)
# Maximum number of recorded tasks per IP (enforced in UserRecorder.check_record).
LimitTask = int(os.environ['LimitTask'])

# Filesystem layout, anchored at the directory that contains this source file.
proj_dir = os.path.split(os.path.abspath(__file__))[0]
data_dir = os.path.join(proj_dir, 'Datas')   # task templates/showcases and per-user records
tmpFolder = os.path.join(proj_dir, 'tmp')    # scratch images written before upload
os.makedirs(tmpFolder, exist_ok=True)
def load_pkl(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only call this on files this application wrote itself.
    """
    handle = open(path, 'rb')
    try:
        obj = pickle.load(handle)
    finally:
        handle.close()
    return obj
def save_pkl(data, path, reweite=False):
    """Store a mapping in a pickle file, merging into any existing file.

    If *path* does not exist yet, or ``reweite`` ("rewrite") is true, *data*
    is written as-is. Otherwise the stored mapping is loaded, the entries of
    *data* are copied over it (existing keys overwritten, new keys added), and
    the merged mapping is written back.

    Args:
        data: mapping to store.
        path: destination file; missing parent directories are created.
        reweite: force an overwrite instead of a merge.

    Returns:
        The object that is now stored at *path*.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises FileNotFoundError for bare file names
        os.makedirs(parent, exist_ok=True)
    if os.path.exists(path) and not reweite:
        merged = load_pkl(path)
        merged.update(data)
        data = merged
    with open(path, 'wb') as file:
        pickle.dump(data, file, protocol=4)
    return data
def checkToken(token, remote_check=False):
    '''
    Check whether a user-supplied token is valid.

    Remote validation is currently switched off: with the default
    ``remote_check=False`` this always returns False, matching the previous
    hard-coded early return (which left the request code unreachable).
    Pass ``remote_check=True`` to POST the token to ``tokenUrl`` and accept
    it when the JSON reply carries a ``res`` value greater than 0.

    Returns:
        bool: True only when the remote check is enabled and succeeds.
    '''
    if not remote_check:
        return False
    params = {'token': str(token)}
    ret = requests.post(f"{tokenUrl}", data=json.dumps(params))
    print(ret)
    try:
        body = ret.json()
    except ValueError:  # non-JSON reply, e.g. an HTML error page
        body = None
    if ret.status_code == 200 and isinstance(body, dict) and 'res' in body:
        return body['res'] > 0
    print(body, ret.status_code, 'call token failed')
    return False
class UserRecorder(object):
    """Per-user task history persisted as pickle files.

    Files live under ``data_dir/UserRecord_<taskType>/``: ``Tokens/<token>.pkl``
    when a token is given, otherwise ``Ips/<ip>.pkl``. Each file holds an
    ordered mapping ``taskId -> {'input1', 'output1', 'status'}``.
    """

    def __init__(self, ):
        super().__init__()
        record_dir = os.path.join(data_dir, f'UserRecord_{taskType}')
        self.ip_dir = os.path.join(record_dir, 'Ips')
        self.token_dir = os.path.join(record_dir, 'Tokens')
        os.makedirs(self.ip_dir, exist_ok=True)
        os.makedirs(self.token_dir, exist_ok=True)

    @staticmethod
    def _safe_name(identity):
        """Turn *identity* into a single file-name component.

        ip/token are presumably client-controlled, so path separators are
        replaced to keep the record file inside its directory (a token of
        "../x" must not escape ``token_dir``). Ordinary IPs/tokens are unchanged.
        """
        return identity.replace('/', '_').replace('\\', '_')

    def _record_path(self, ip="", token=""):
        """Return ``(identity, record_file_path)``; the token takes priority over the ip."""
        if len(token) == 0:
            return ip, os.path.join(self.ip_dir, f'{self._safe_name(ip)}.pkl')
        return token, os.path.join(self.token_dir, f'{self._safe_name(token)}.pkl')

    def save_record(self, taskRes, ip="", token=""):
        """Insert/overwrite the entry for ``taskRes['id']`` in the user's record file.

        Returns:
            The merged record mapping, or None when neither ip nor token is given.
        """
        if len(ip) == 0 and len(token) == 0:
            return
        _, record_path = self._record_path(ip=ip, token=token)
        taskId = taskRes['id']
        status = taskRes['status']
        if 'output' in taskRes:
            job_results = taskRes['output']['job_results']
            input1, output1 = job_results['input1'], job_results['output1']
        else:
            input1, output1 = None, None
        data = OrderedDict()
        data[taskId] = {'input1': input1, 'output1': output1, 'status': status, }
        return save_pkl(data, record_path, reweite=False)

    def check_record(self, ip="", token=""):
        """Decide whether this user may submit another task.

        Returns:
            ``(allowed, message)``. With a token, validity is delegated to
            ``checkToken``. Otherwise the ip's recorded task count is compared
            against ``LimitTask``. A request with neither ip nor token cannot be
            rate-limited and is rejected (previously it crashed unpacking None).
        """
        if len(token) > 0:
            if checkToken(token):
                return True, ""
            return False, "token is invalid"
        record = self.get_record(ip=ip, token=token)
        if record is None:
            return False, "unable to identify user"
        _, total_n, _ = record
        if total_n >= LimitTask:
            return False, no_more_attempts
        return True, ""

    def get_record(self, ip="", token=""):
        """Summarise the user's task history.

        Returns:
            None when neither ip nor token is given. Otherwise
            ``(shows, total_n, msg)``: ``shows`` is a 6-slot list holding
            ``<img>`` tags for the input/output pairs of up to three of the
            most recently added completed tasks (padded with None),
            ``total_n`` is the number of recorded tasks, and ``msg`` is a
            human-readable summary.
        """
        if len(ip) == 0 and len(token) == 0:
            return
        identity, record_path = self._record_path(ip=ip, token=token)
        record_data = load_pkl(record_path) if os.path.exists(record_path) else {}
        total_n = len(record_data)
        success_n, fail_n, process_n = 0, 0, 0
        shows = [None] * 6
        show_i = 0
        # New task ids are appended by save_pkl's merge, so iterating in
        # reverse visits the most recently added tasks first.
        for key in reversed(record_data):
            entry = record_data[key]
            status = entry['status']
            if status in ('FAILED', 'CANCELLED', 'TIMED_OUT'):
                fail_n += 1
            elif status == 'COMPLETED':
                success_n += 1
                input1 = entry['input1']
                if input1 is not None and show_i <= 2:
                    output1 = entry['output1']
                    shows[show_i * 2] = f"<img src=\"{input1}\" >"
                    shows[show_i * 2 + 1] = f"<img src=\"{output1}\" >"
                    show_i += 1
            elif status in ('IN_QUEUE', 'IN_PROGRESS'):
                process_n += 1
        msg = f"Dear {identity}, You have {total_n} tasks, {success_n} successed, {fail_n} failed, {process_n} processing, "
        return shows, total_n, msg
def get_temps_examples(taskType):
    """List the template images available for *taskType*.

    Scans ``data_dir/task<taskType>/temps`` for entries whose name contains a
    dot but does not start with one (hidden files are skipped). Returns one
    ``[path]`` row per entry in descending name order, or [] when the
    directory is missing.
    """
    temp_dir = os.path.join(data_dir, f'task{taskType}/temps')
    if not os.path.exists(temp_dir):
        return []
    # listdir names are unique, so a reverse sort equals sort-then-reverse.
    return [
        [os.path.join(temp_dir, name)]
        for name in sorted(os.listdir(temp_dir), reverse=True)
        if '.' in name and not name.startswith('.')
    ]
def get_user_examples(taskType):
    """List the sample source images available for *taskType*.

    Scans ``data_dir/task<taskType>/srcs`` for entries whose name contains a
    dot but does not start with one (hidden files are skipped). Returns one
    ``[path]`` row per entry in ascending name order, or [] when the
    directory is missing.
    """
    user_dir = os.path.join(data_dir, f'task{taskType}/srcs')
    if not os.path.exists(user_dir):
        return []
    rows = []
    for name in sorted(os.listdir(user_dir)):
        if '.' not in name or name.startswith('.'):
            continue
        rows.append([os.path.join(user_dir, name)])
    return rows
def get_showcase_examples(taskType):
    """Return showcase rows ``[template, source, result]`` for *taskType*.

    Each relative path is resolved against ``data_dir`` and asserted to exist
    (the AssertionError message is the missing path). Task types without a
    showcase yield an empty list.
    """
    catalogue = (
        ("3", (
            ("task3/temps/flow-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_flower-water.jpg"),
            ("task3/temps/mountain-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_mountain-water.jpg"),
            ("task3/temps/rock-on-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_rock-on-water.jpg"),
        )),
        ("4", (
            ("task4/temps/Vivienne.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_vivienne.jpg"),
            ("task4/temps/Bella.jpg", "task4/srcs/src04.jpg", "task4/showcases/src04_balle.jpg"),
            ("task4/temps/Nia.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_nia.jpg"),
            ("task4/temps/Leo.jpg", "task4/srcs/src03.jpg", "task4/showcases/src03_male.jpg"),
        )),
        ("6", (
            ("task6/temps/niantu.jpg", "task6/srcs/src01.jpg", "task6/showcases/src01_niantu.jpg"),
            ("task6/temps/3d-shouban.jpg", "task6/srcs/src02.jpg", "task6/showcases/src02_shouban.jpg"),
        )),
        ("5", (
            ("task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_street.jpg"),
            ("task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_walk.jpg"),
        )),
        ("1", (
            ("task1/temps/caption.jpg", "task1/srcs/src01.jpg", "task1/showcases/src01_seg.png"),
        )),
    )
    selected = ()
    for key, entries in catalogue:
        if taskType == key:
            selected = entries
            break
    resolved = []
    for entry in selected:
        full_row = []
        for rel_path in entry:
            full_path = os.path.join(data_dir, rel_path)
            assert os.path.exists(full_path), full_path
            full_row.append(full_path)
        resolved.append(full_row)
    return resolved
def get_result_example(cloth_id, pose_id):
    """Build the path ``data_dir/ResultImgs/<cloth_id>_<pose_id>.jpg``.

    The path is only constructed; whether the file exists is not checked.
    """
    file_name = "{}_{}.jpg".format(cloth_id, pose_id)
    return os.path.join(data_dir, 'ResultImgs', file_name)
def upload_user_img(clientIp, img):
    """Upload an image through the service's two-step upload flow.

    Writes *img* as a uniquely named JPEG in ``tmpFolder``, asks
    ``{apiUrl}/upload`` for an upload URL (``upload1`` in the JSON reply),
    PUTs the file to that URL, and returns the URL on success.

    Args:
        clientIp: caller's IP; its dots are stripped and it prefixes the file name.
        img: 3-D image array; its last axis is reversed before ``cv2.imwrite``
            (presumably RGB -> BGR; confirm with the caller).

    Returns:
        The upload URL, or '' on any failure. The temporary file is always
        removed, even when a request raises.
    """
    timeId = int(str(time.time()).replace(".", "")) + random.randint(1000, 9999)
    fileName = clientIp.replace(".", "") + str(timeId) + ".jpg"
    local_path = os.path.join(tmpFolder, fileName)
    res = ""
    try:
        if not cv2.imwrite(local_path, img[:, :, ::-1]):
            print(local_path, 'write temp image failed')
            return res
        params = {'token': uploadToken, 'input1': fileName, 'input2': ''}
        ret = requests.post(f"{apiUrl}/upload", data=json.dumps(params))
        try:
            body = ret.json()
        except ValueError:  # non-JSON reply, e.g. an HTML error page
            body = None
        if ret.status_code == 200 and isinstance(body, dict) and 'upload1' in body:
            upload_url = body['upload1']
            with open(local_path, 'rb') as file:
                response = requests.put(upload_url, data=file)
            if response.status_code == 200:
                res = upload_url
            else:
                print(response.status_code, 'put upload failed')
        else:
            print(body, ret.status_code, 'call upload failed')
    finally:
        # Clean up the scratch file whatever happened above.
        if os.path.exists(local_path):
            os.remove(local_path)
    return res
def publicSelfitTask(image, temp_image, caption_text):
    """Submit a generation task to ``{apiUrl}/public`` and return its task id.

    Args:
        image: value sent as ``image`` (presumably a URL from upload_user_img).
        temp_image: template image path. Its base name up to the first dot is
            sent as ``param1``. It is ignored for task type '5', which has no
            template.
        caption_text: sent, stringified, as ``param2``.

    Returns:
        The ``id`` field of the JSON reply, or None if the call failed.
        Every failure is logged.
    """
    if taskType in ['5']:  # task 5 takes no template
        temp_name = ''
    else:
        temp_name = os.path.basename(temp_image).split('.')[0]
    params = {'openId': openId, 'apiKey': apiKey, 'image': image, 'mask': "",
              "image_type": "2", "task_type": taskType, 'param1': temp_name,
              'param2': str(caption_text), 'param3': "1", 'param4': "", 'param5': ""}
    ret = requests.post(f"{apiUrl}/public", data=json.dumps(params))
    print(ret)
    try:
        body = ret.json()
    except ValueError:  # non-JSON reply, e.g. an HTML error page
        body = None
    if ret.status_code == 200 and isinstance(body, dict) and 'id' in body:
        print(body)
        return body['id']
    print(body, ret.status_code, 'call public failed')
    return None
def getTaskRes(taskId):
    """Fetch the status record of *taskId* from ``{apiUrl}/status``.

    Returns:
        The decoded JSON reply when it is a dict containing ``status``,
        otherwise None. Every failure is logged, including a 200 reply
        without a ``status`` field and non-JSON replies.
    """
    params = {'id': taskId}
    ret = requests.post(f"{apiUrl}/status", data=json.dumps(params))
    print(ret)
    try:
        body = ret.json()
    except ValueError:  # non-JSON reply, e.g. an HTML error page
        body = None
    if ret.status_code == 200 and isinstance(body, dict) and 'status' in body:
        print(body)
        return body
    print(body, ret.status_code, 'call status failed')
    return None
def check_region(ip, timeout=10):
    """Return False if *ip* geolocates to a nation listed in ``Regions``.

    Queries Meitu's ip_location endpoint and inspects only the first entry of
    the reply's ``data`` mapping. Returns True (allowed) when that nation is
    not in ``Regions``, and also when the service returns no entries. The
    latter previously returned an implicit None, which callers would treat
    as blocked.

    NOTE(review): ``Regions`` is the raw environment string, so
    ``nat in Regions`` is a substring test (an empty nation matches any
    value). Confirm the intended format (comma-separated?) before tightening.

    Args:
        ip: address to look up; passed as a properly encoded query parameter.
        timeout: seconds to wait for the lookup before requests raises.

    Raises:
        Network errors from requests, or KeyError/ValueError on unexpected
        replies. ``check_region_warp`` maps any exception to "allowed".
    """
    ret = requests.get("https://webapi-pc.meitu.com/common/ip_location",
                       params={'ip': ip}, timeout=timeout)
    data = ret.json()['data']
    for k in data:
        nat = data[k]['nation']
        if nat in Regions:
            print(nat, 'invalid')
            return False
        print(nat, 'valid')
        return True
    print(ip, 'no location data, treated as valid')
    return True
def check_region_warp(ip):
    """Fail-open wrapper around ``check_region``.

    Returns whatever ``check_region`` returns. If it raises anything, the
    exception is printed and True (allowed) is returned instead.
    """
    allowed = True
    try:
        allowed = check_region(ip)
    except Exception as err:
        print(err)
    return allowed