|
|
|
from config import * |
|
|
|
|
|
# Deployment settings, all read from the environment. A missing variable
# raises KeyError at import time, so misconfiguration fails fast.
apiUrl, uploadToken, openId, apiKey, Regions, tokenUrl = (
    os.environ[var_name]
    for var_name in ('apiUrl', 'uploadToken', 'openId', 'apiKey', 'Regions', 'tokenUrl')
)
# Task quota applied to IP-identified users in UserRecorder.check_record.
LimitTask = int(os.environ['LimitTask'])

# Paths are anchored at this file's directory so they work from any CWD.
proj_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(proj_dir, 'Datas')   # example images and user records
tmpFolder = os.path.join(proj_dir, 'tmp')    # scratch files written before upload
os.makedirs(tmpFolder, exist_ok=True)
|
|
|
|
|
|
|
def load_pkl(path):
    """Unpickle and return the object stored in the file at *path*.

    Only use this on files this application wrote itself: unpickling runs
    arbitrary code embedded in the file.
    """
    with open(path, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj
|
|
|
def save_pkl(data, path, reweite=False):
    """Pickle *data* to *path*, merging into any existing mapping there.

    Args:
        data: mapping to store.
        path: destination file; missing parent directories are created.
        reweite: ("rewrite") when True, overwrite the file with *data*;
            when False and the file already exists, its mapping is loaded,
            updated with the keys of *data*, and written back.

    Returns:
        The object that was written: *data* itself on a fresh write or
        rewrite, otherwise the merged mapping loaded from disk.
    """
    directory = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare file names.
    if directory:
        os.makedirs(directory, exist_ok=True)

    if os.path.exists(path) and not reweite:
        # Read-modify-write; no locking is done here, so concurrent writers
        # to the same path can lose each other's updates.
        merged = load_pkl(path)
        merged.update(data)
    else:
        merged = data

    with open(path, 'wb') as file:
        pickle.dump(merged, file, protocol=4)
    return merged
|
|
|
def checkToken(token, *, timeout=60):
    '''
    Check whether an API token is valid, i.e. it still has credits left.

    POSTs to ``tokenUrl`` with the token as a Bearer credential. The body asks
    for ``cost_credits=1``; whether that deducts a credit is decided by the
    token service (not visible here — confirm).

    Args:
        token: the user's API key.
        timeout: seconds passed to ``requests`` so the call cannot hang forever.

    Returns:
        True only if the service answers 200 with ``left_credits > 0``.
    '''
    params = {'trans_type': 'hf_space', 'cost_credits': 1}
    headers = {"Authorization": f"Bearer {token}"}
    ret = requests.post(tokenUrl, data=json.dumps(params), headers=headers, timeout=timeout)
    print(ret)

    if ret.status_code != 200:
        try:
            detail = ret.json()
        except ValueError:  # error pages are not always JSON
            detail = ret.text
        print(detail, ret.status_code, 'call token failed')
        return False

    payload = ret.json()
    return 'left_credits' in payload and payload['left_credits'] > 0
|
|
|
class UserRecorder:
    """Persist and query per-user task history as pickle files.

    Records live under ``<data_dir>/UserRecord_<taskType>/``. Callers with an
    API token are keyed by that token in ``Tokens/``; everyone else is keyed
    by IP in ``Ips/``. Each file maps task id -> {'input1', 'output1', 'status'}.
    """

    def __init__(self):
        super().__init__()
        record_dir = os.path.join(data_dir, f'UserRecord_{taskType}')
        self.ip_dir = os.path.join(record_dir, 'Ips')
        self.token_dir = os.path.join(record_dir, 'Tokens')
        os.makedirs(self.ip_dir, exist_ok=True)
        os.makedirs(self.token_dir, exist_ok=True)

    def _record_path(self, ip="", token=""):
        """Return ``(identity, record_path)``, or None if no identity is given.

        A non-empty token takes precedence over the IP.

        Raises:
            ValueError: if the identity is not a plain file name.
        """
        if token:
            identity, folder = token, self.token_dir
        elif ip:
            identity, folder = ip, self.ip_dir
        else:
            return None
        file_name = f'{identity}.pkl'
        # ip/token come from the client: refuse names such as "../x" that would
        # escape the record folder (and let load_pkl unpickle arbitrary files).
        if os.path.basename(file_name) != file_name:
            raise ValueError('invalid record identity')
        return identity, os.path.join(folder, file_name)

    def save_record(self, taskRes, ip="", token=""):
        """Merge one task result into the caller's record file.

        Args:
            taskRes: task payload with 'id' and 'status'; if it has 'output',
                ``taskRes['output']['job_results']`` must hold 'input1' and
                'output1'.
            ip, token: caller identity; the token wins when both are set.

        Returns:
            The merged record mapping, or None when no identity was given.

        Raises:
            ValueError: if the identity is not a plain file name.
        """
        located = self._record_path(ip=ip, token=token)
        if located is None:
            return None
        _, record_path = located

        taskId = taskRes['id']
        status = taskRes['status']
        if 'output' in taskRes:
            job_results = taskRes['output']['job_results']
            input1, output1 = job_results['input1'], job_results['output1']
        else:
            input1, output1 = None, None

        data = OrderedDict()
        data[taskId] = {'input1': input1, 'output1': output1, 'status': status}
        return save_pkl(data, record_path, reweite=False)

    def check_record(self, ip="", token=""):
        """Decide whether the caller may submit another task.

        Token users are validated remotely via checkToken. IP users are
        allowed while their recorded task count is below LimitTask.

        Returns:
            ``(allowed, message)``; message is "" when allowed.
        """
        if token:
            if checkToken(token):
                return True, ""
            return False, "faild, api key is invalid"

        record = self.get_record(ip=ip, token=token)
        # get_record returns None when there is no identity at all. Nothing can
        # have been recorded for such a caller, so count it as zero tasks.
        total_n = 0 if record is None else record[1]
        if total_n >= LimitTask:
            return False, no_more_attempts
        return True, ""

    def get_record(self, ip="", token=""):
        """Summarise the caller's task history.

        Returns:
            None when neither ip nor token is given. Otherwise a tuple
            ``(shows, total_n, msg)``:
            - ``shows`` is a 6-slot list of HTML ``<img>`` tags (input, output)
              for up to the three most recently inserted COMPLETED tasks that
              have an input; unused slots stay None.
            - ``total_n`` is the number of recorded tasks.
            - ``msg`` is a summary string.

        Raises:
            ValueError: if the identity is not a plain file name.
        """
        located = self._record_path(ip=ip, token=token)
        if located is None:
            return None
        identity, record_path = located

        record_data = load_pkl(record_path) if os.path.exists(record_path) else {}
        total_n = len(record_data)
        success_n, fail_n, process_n = 0, 0, 0
        shows = [None] * 6
        show_i = 0

        # Walk newest-inserted first so the gallery shows the latest results.
        for key in reversed(record_data):
            record = record_data[key]
            status = record['status']
            if status in ('FAILED', 'CANCELLED', 'TIMED_OUT'):
                fail_n += 1
            elif status == 'COMPLETED':
                success_n += 1
                if record['input1'] is not None and show_i <= 2:
                    input1 = record['input1']
                    output1 = record['output1']
                    shows[show_i * 2] = f"<img src=\"{input1}\" >"
                    shows[show_i * 2 + 1] = f"<img src=\"{output1}\" >"
                    show_i += 1
            elif status in ('IN_QUEUE', 'IN_PROGRESS'):
                process_n += 1

        msg = f"Dear {identity}, You have {total_n} tasks, {success_n} successed, {fail_n} failed, {process_n} processing, "
        return shows, total_n, msg
|
|
|
|
|
def get_temps_examples(taskType):
    """List template images for *taskType*, newest name first.

    Scans ``<data_dir>/task<taskType>/temps``. It returns one single-element
    list per file, ordered by descending file name. Entries with no '.' or
    whose name starts with '.' are skipped. A missing folder yields [].
    """
    temp_dir = os.path.join(data_dir, f'task{taskType}/temps')
    if not os.path.exists(temp_dir):
        return []
    # Directory entries are unique, so sorting descending equals sort-then-reverse.
    return [
        [os.path.join(temp_dir, name)]
        for name in sorted(os.listdir(temp_dir), reverse=True)
        if '.' in name and not name.startswith('.')
    ]
|
|
|
def get_user_examples(taskType):
    """List user source images for *taskType* in ascending name order.

    Scans ``<data_dir>/task<taskType>/srcs`` and returns one single-element
    list per file. Entries with no '.' or whose name starts with '.' are
    skipped. A missing folder yields [].
    """
    user_dir = os.path.join(data_dir, f'task{taskType}/srcs')
    if not os.path.exists(user_dir):
        return []
    return [
        [os.path.join(user_dir, name)]
        for name in sorted(os.listdir(user_dir))
        if '.' in name and not name.startswith('.')
    ]
|
|
|
def get_showcase_examples(taskType):
    """Return showcase rows ``[template, source, result]`` for a task type.

    Paths are resolved under ``data_dir``. ``taskType`` may be given as a str
    or an int (it is compared via ``str()``). Unknown task types yield [].

    Raises:
        FileNotFoundError: if a listed showcase file is missing. This is an
            explicit raise rather than ``assert``, so it still fires under
            ``python -O``.
    """
    showcases = {
        "1": [
            ["task1/temps/caption.jpg", "task1/srcs/src01.jpg", "task1/showcases/src01_seg.png"],
        ],
        "2": [
            ["task2/temps/caption.jpg", "task2/srcs/street.webp", "task2/showcases/out1.jpg"],
        ],
        "3": [
            ["task3/temps/flow-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_flower-water.jpg"],
            ["task3/temps/mountain-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_mountain-water.jpg"],
            ["task3/temps/rock-on-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_rock-on-water.jpg"],
        ],
        "4": [
            ["task4/temps/Vivienne.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_vivienne.jpg"],
            ["task4/temps/Bella.jpg", "task4/srcs/src04.jpg", "task4/showcases/src04_balle.jpg"],
            ["task4/temps/Nia.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_nia.jpg"],
            ["task4/temps/Leo.jpg", "task4/srcs/src03.jpg", "task4/showcases/src03_male.jpg"],
        ],
        "5": [
            ["task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_street.jpg"],
            ["task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_walk.jpg"],
        ],
        "6": [
            ["task6/temps/niantu.jpg", "task6/srcs/src01.jpg", "task6/showcases/src01_niantu.jpg"],
            ["task6/temps/3d-shouban.jpg", "task6/srcs/src02.jpg", "task6/showcases/src02_shouban.jpg"],
        ],
        "7": [
            ["task7/temps/task7.webp", "task7/srcs/305.jpg", "task7/showcases/task7out.jpg"],
        ],
        "9": [
            ["task9/temps/caption.jpg", "task9/srcs/use1.jpg", "task9/showcases/show0.jpg"],
            ["task9/temps/caption.jpg", "task9/srcs/use2.jpg", "task9/showcases/show1.webp"],
            ["task9/temps/caption.jpg", "task9/srcs/use3.jpg", "task9/showcases/show2.webp"],
        ],
    }

    # Build fresh lists so callers may mutate the result freely.
    examples = [
        [os.path.join(data_dir, rel_path) for rel_path in row]
        for row in showcases.get(str(taskType), [])
    ]
    for row in examples:
        for path in row:
            if not os.path.exists(path):
                raise FileNotFoundError(path)
    return examples
|
|
|
def get_result_example(cloth_id, pose_id):
    """Return the path of the stored result image for a (cloth, pose) pair.

    The path is ``<data_dir>/ResultImgs/<cloth_id>_<pose_id>.jpg``. Whether
    the file exists is not checked.
    """
    return os.path.join(data_dir, 'ResultImgs', f"{cloth_id}_{pose_id}.jpg")
|
|
|
def upload_user_img_mask(clientIp, img, mask=None, taskType='1', *, timeout=60):
    """Upload a user image (and optional mask) through the backend upload API.

    Steps:
    1. The image and optional mask are written as temporary JPEGs in
       ``tmpFolder``.
    2. ``<apiUrl>/upload`` is asked for target URLs: ``upload1`` for the
       image and ``upload2`` for the mask.
    3. Each file is PUT to its URL.

    The temporary files are always removed, even if a request raises.

    Args:
        clientIp: client address; with dots removed it prefixes the file names.
        img: image array whose last axis is reversed before ``cv2.imwrite``
            (presumably RGB -> BGR for OpenCV — confirm against callers).
        mask: optional mask array, written as-is.
        taskType: task types '8' and '9' skip the upload entirely.
        timeout: seconds passed to every ``requests`` call.

    Returns:
        ``(image_url, mask_url)``; each is '' when that upload did not succeed.
        Task types '8' and '9' return the fixed pair ``("img", '')``.
    """
    if taskType in ['8', '9']:
        return "img", ''

    # Timestamp digits plus a random offset, to make name collisions unlikely.
    timeId = int(str(time.time()).replace(".", "")) + random.randint(1000, 9999)
    stem = clientIp.replace(".", "") + str(timeId)
    fileName = stem + ".jpg"
    filemName = stem + "_m.jpg"
    local_path = os.path.join(tmpFolder, fileName)
    localm_path = os.path.join(tmpFolder, filemName)
    # Both temporary files are JPEGs, so the upload Content-Type is fixed.
    headers = {"Content-Type": "image/jpeg"}

    res, uploadm_url = "", ''
    try:
        cv2.imwrite(local_path, img[:, :, ::-1].astype(np.uint8))
        if mask is not None:
            cv2.imwrite(localm_path, mask)

        params = {'token': uploadToken, 'input1': fileName, 'input2': filemName}
        ret = requests.post(f"{apiUrl}/upload", data=json.dumps(params), timeout=timeout)
        if ret.status_code != 200:
            try:
                detail = ret.json()
            except ValueError:  # error pages are not always JSON
                detail = ret.text
            print(detail, ret.status_code, 'call upload failed')
            return res, uploadm_url

        urls = ret.json()
        if 'upload1' not in urls:
            print(urls, 'call upload returned no upload1')
            return res, uploadm_url

        upload_url = urls['upload1']
        with open(local_path, 'rb') as file:
            response = requests.put(upload_url, data=file, headers=headers, timeout=timeout)
        if response.status_code == 200:
            res = upload_url
        else:
            print(response)

        if mask is not None:
            uploadm_url = urls.get('upload2', '')
            if not uploadm_url:
                print(urls, 'call upload returned no upload2')
            else:
                with open(localm_path, 'rb') as file:
                    response = requests.put(uploadm_url, data=file, headers=headers, timeout=timeout)
                if response.status_code != 200:
                    uploadm_url = ''
                    print(response)
    finally:
        # Always clean up the scratch files, including when a request raises.
        for path in (local_path, localm_path):
            if os.path.exists(path):
                os.remove(path)

    return res, uploadm_url
|
|
|
|
|
def publicSelfitTask(image, mask, temp_image, caption_text, param4_text, param5_text, *, timeout=60):
    """Submit a task to ``<apiUrl>/public`` and return the new task id.

    Args:
        image, mask: sent as-is in the request body.
        temp_image: template image path. Its base name up to the first '.'
            is sent as ``param1``.
        caption_text: sent (stringified) as ``param2``.
        param4_text, param5_text: forwarded verbatim as ``param4``/``param5``.
        timeout: seconds passed to ``requests`` so the call cannot hang forever.

    Returns:
        The task id, or None if the call failed or the response has no 'id'.
    """
    temp_name = os.path.basename(temp_image).split('.')[0]
    params = {'openId': openId, 'apiKey': apiKey, 'image': image, 'mask': mask,
              "image_type": "2", "task_type": taskType, 'param1': temp_name,
              'param2': str(caption_text), 'param3': "1", 'param4': param4_text, 'param5': param5_text}
    ret = requests.post(f"{apiUrl}/public", data=json.dumps(params), timeout=timeout)
    print(ret)

    if ret.status_code == 200:
        payload = ret.json()
        if 'id' in payload:
            return payload['id']
        return None

    try:
        detail = ret.json()
    except ValueError:  # error pages are not always JSON
        detail = ret.text
    print(detail, ret.status_code, 'call public failed')
    return None
|
|
|
def getTaskRes(taskId, taskType, *, timeout=60):
    """Query ``<apiUrl>/status`` for the state of task *taskId*.

    Args:
        taskId: id returned when the task was submitted.
        taskType: task type sent as ``task_type``.
        timeout: seconds passed to ``requests`` so the call cannot hang forever.

    Returns:
        The decoded JSON payload when the call returns 200 and the payload
        contains 'status'; otherwise None.
    """
    params = {'id': taskId, 'task_type': taskType}
    ret = requests.post(f"{apiUrl}/status", data=json.dumps(params), timeout=timeout)

    if ret.status_code == 200:
        payload = ret.json()
        if 'status' in payload:
            return payload
        return None

    try:
        detail = ret.json()
    except ValueError:  # error pages are not always JSON
        detail = ret.text
    print(detail, ret.status_code, 'call status failed')
    return None
|
|
|
@func_timeout.func_set_timeout(10)
def check_region(ip):
    """Return False if *ip* geolocates to a blocked nation, True otherwise.

    Only the first entry of the lookup's ``data`` mapping is inspected. When
    the lookup returns no entries, the IP is allowed. This matches the
    fail-open policy of ``check_region_warp``, whereas the old code returned
    None here.
    """
    # Pass the ip via params so requests URL-encodes it.
    ret = requests.get("https://webapi-pc.meitu.com/common/ip_location",
                       params={'ip': ip}, timeout=10)
    data = ret.json()['data']
    for k in data:
        nat = data[k]['nation']
        # NOTE(review): Regions is the raw env string, so this is a substring
        # test (e.g. an empty nation always matches). Confirm the intended format.
        if nat in Regions:
            print(nat, 'invalid')
            return False
        print(nat, 'valid')
        return True
    return True
|
def check_region_warp(ip):
    """Fail-open wrapper around check_region: errors allow the request.

    func_timeout's ``FunctionTimedOut`` derives from BaseException, so a
    plain ``except Exception`` would let the 10 s lookup timeout escape. It
    is therefore caught explicitly.
    """
    try:
        return check_region(ip)
    except (Exception, func_timeout.FunctionTimedOut) as e:
        print(e)
        return True
|
|