# AI-Object-Removal / utils.py
# Author: selfitcamera — commit f30af9d ("init"), 11.2 kB
# (Header captured from the hosting site's file viewer: raw / history / blame.)
from config import *

# Service endpoints and credentials are read from the environment. Every one of
# them is required: a missing variable makes this module fail to import with
# KeyError.
apiUrl, uploadToken, openId, apiKey, Regions, tokenUrl = (
    os.environ[_env_name]
    for _env_name in ('apiUrl', 'uploadToken', 'openId', 'apiKey', 'Regions', 'tokenUrl')
)
# Task quota compared against in the IP-identified branch of UserRecorder.check_record.
LimitTask = int(os.environ['LimitTask'])

# Directory layout, anchored at the directory containing this file.
proj_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(proj_dir, 'Datas')  # example images and per-user records
tmpFolder = os.path.join(proj_dir, 'tmp')   # scratch files written before upload
os.makedirs(tmpFolder, exist_ok=True)
def load_pkl(path):
    """Unpickle and return the object stored in the file at *path*.

    Only use on files this application wrote itself; unpickling untrusted
    data can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj
def save_pkl(data, path, reweite=False):
    """Pickle *data* to *path*, merging into an existing file unless told to overwrite.

    Args:
        data: mapping to store.
        path: destination file. Missing parent directories are created.
        reweite: (sic, "rewrite") when True, any existing file is replaced
            outright. When False and *path* already exists, the stored mapping
            is loaded, updated with *data* (new keys added, existing keys
            overwritten) and the merged mapping is saved.

    Returns:
        The mapping that was written: *data* itself on a fresh write, or the
        merged mapping when merging into an existing file.
    """
    import tempfile

    parent = os.path.dirname(path)
    if parent:
        # os.makedirs('') raises, so only create a directory when path has one.
        os.makedirs(parent, exist_ok=True)

    if os.path.exists(path) and not reweite:
        with open(path, 'rb') as fh:
            merged = pickle.load(fh)
        merged.update(data)
        data = merged

    # Write to a temp file in the same directory, then atomically swap it in.
    # A crash mid-write therefore cannot leave a truncated pickle at *path*.
    # Note: mkstemp creates the file with mode 0600.
    fd, tmp_path = tempfile.mkstemp(dir=parent or '.', suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as fh:
            pickle.dump(data, fh, protocol=4)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return data
def checkToken(token):
    """Check whether *token* is a valid API key.

    Validation is currently switched off: this always returns False, so every
    token is treated as invalid by callers. The remote check that used to sit
    unreachable after the early return lives in ``_query_token_service``.
    """
    # The original code returned False unconditionally ahead of the remote
    # call (presumably to disable token access). To re-enable:
    #     return _query_token_service(token)
    return False


def _query_token_service(token):
    """Ask the token service at ``tokenUrl`` whether *token* is valid.

    Currently unused (see checkToken). Returns True only when the service
    answers HTTP 200 with a JSON body whose ``res`` field is > 0.
    """
    params = {'token': str(token)}
    ret = requests.post(f"{tokenUrl}", data=json.dumps(params))
    print(ret)
    try:
        payload = ret.json()
    except ValueError:
        payload = None
    if ret.status_code != 200:
        print(payload if payload is not None else ret.text, ret.status_code, 'call token failed')
        return False
    if payload is None or 'res' not in payload:
        return False
    return payload['res'] > 0
class UserRecorder(object):
    """Persist and summarise each user's task history as pickle files.

    A user is identified by API token when one is supplied, otherwise by
    client IP. Their record is stored as ``<identity>.pkl`` under
    ``<data_dir>/UserRecord_<taskType>/Tokens`` or ``.../Ips``.
    """

    # Task statuses grouped for the summary counters in get_record().
    _FAILED_STATUSES = ('FAILED', 'CANCELLED', 'TIMED_OUT')
    _DONE_STATUSES = ('COMPLETED',)
    _RUNNING_STATUSES = ('IN_QUEUE', 'IN_PROGRESS')
    # Number of (input, output) image pairs rendered by get_record().
    _MAX_SHOWN_PAIRS = 3

    def __init__(self):
        super(UserRecorder, self).__init__()
        record_dir = os.path.join(data_dir, f'UserRecord_{taskType}')
        self.ip_dir = os.path.join(record_dir, 'Ips')
        self.token_dir = os.path.join(record_dir, 'Tokens')
        os.makedirs(self.ip_dir, exist_ok=True)
        os.makedirs(self.token_dir, exist_ok=True)

    @staticmethod
    def _safe_name(identity):
        """Escape path separators so a user-supplied identity stays inside its record dir.

        Ordinary IPs and tokens contain none of these characters, so their
        file names are unchanged.
        """
        for char, escaped in (('/', '%2F'), ('\\', '%5C'), ('\x00', '%00')):
            identity = identity.replace(char, escaped)
        return identity

    def _record_path(self, ip="", token=""):
        """Return ``(identity, record_path)``, or None when neither ip nor token is given.

        The token takes precedence over the IP.
        """
        if token:
            return token, os.path.join(self.token_dir, f'{self._safe_name(token)}.pkl')
        if ip:
            return ip, os.path.join(self.ip_dir, f'{self._safe_name(ip)}.pkl')
        return None

    def save_record(self, taskRes, ip="", token=""):
        """Add or overwrite the entry for task ``taskRes['id']`` in the user's record.

        Args:
            taskRes: task status dict with ``id`` and ``status`` keys and, once
                available, ``output.job_results.input1``/``output1``.
            ip, token: user identity (token wins when both are given).

        Returns:
            The user's full merged record, or None when no identity was given.
        """
        located = self._record_path(ip=ip, token=token)
        if located is None:
            return None
        _, record_path = located
        taskId = taskRes['id']
        status = taskRes['status']
        if 'output' in taskRes:
            job_results = taskRes['output']['job_results']
            input1, output1 = job_results['input1'], job_results['output1']
        else:
            input1, output1 = None, None
        data = OrderedDict()
        data[taskId] = {'input1': input1, 'output1': output1, 'status': status}
        return save_pkl(data, record_path, reweite=False)

    def check_record(self, ip="", token=""):
        """Decide whether the user may submit another task.

        Returns:
            ``(allowed, message)``. With a token, the result depends solely on
            checkToken. Otherwise the IP's task count is compared with
            LimitTask. A request with neither an IP nor a token is refused,
            because there is nothing to meter it against.
        """
        if token:
            if checkToken(token):
                return True, ""
            return False, "api key is invalid"
        record = self.get_record(ip=ip, token=token)
        if record is None:
            # Previously this unpacked None and raised TypeError.
            return False, "unable to identify the request"
        _, total_n, _ = record
        if total_n >= LimitTask:
            return False, no_more_attempts
        return True, ""

    def get_record(self, ip="", token=""):
        """Summarise the user's task history.

        Returns:
            None when no identity was given. Otherwise ``(shows, total_n, msg)``:
            ``shows`` is a 6-slot list holding ``<img>`` tags for up to three
            (input, output) pairs of completed tasks (unused slots stay None).
            ``total_n`` is the number of recorded tasks, and ``msg`` is a
            human-readable summary.
        """
        located = self._record_path(ip=ip, token=token)
        if located is None:
            return None
        identity, record_path = located
        record_data = load_pkl(record_path) if os.path.exists(record_path) else {}
        total_n = len(record_data)
        success_n = fail_n = process_n = 0
        shows = [None] * (2 * self._MAX_SHOWN_PAIRS)
        show_i = 0
        # Walk in reverse insertion order, so the most recently added tasks
        # fill the gallery first.
        for key in reversed(record_data):
            entry = record_data[key]
            status = entry['status']
            if status in self._FAILED_STATUSES:
                fail_n += 1
            elif status in self._DONE_STATUSES:
                success_n += 1
                if entry['input1'] is not None and show_i < self._MAX_SHOWN_PAIRS:
                    input1 = entry['input1']
                    output1 = entry['output1']
                    shows[show_i * 2] = f"<img src=\"{input1}\" >"
                    shows[show_i * 2 + 1] = f"<img src=\"{output1}\" >"
                    show_i += 1
            elif status in self._RUNNING_STATUSES:
                process_n += 1
        msg = f"Dear {identity}, You have {total_n} tasks, {success_n} successed, {fail_n} failed, {process_n} processing, "
        return shows, total_n, msg
def get_temps_examples(taskType):
    """List the template images for *taskType* as single-element rows.

    Reads ``<data_dir>/task<taskType>/temps``. Only names that contain a dot
    and do not start with one are kept. Rows come back in descending file-name
    order; a missing directory yields ``[]``.
    """
    temp_dir = os.path.join(data_dir, f'task{taskType}/temps')
    if not os.path.exists(temp_dir):
        return []
    # Directory entries are unique, so a reverse sort equals sort-then-reverse.
    return [
        [os.path.join(temp_dir, name)]
        for name in sorted(os.listdir(temp_dir), reverse=True)
        if '.' in name and not name.startswith('.')
    ]
def get_user_examples(taskType):
    """List the user source images for *taskType* as single-element rows.

    Reads ``<data_dir>/task<taskType>/srcs``. Only names that contain a dot
    and do not start with one are kept, in ascending file-name order. A
    missing directory yields ``[]``.
    """
    user_dir = os.path.join(data_dir, f'task{taskType}/srcs')
    if not os.path.exists(user_dir):
        return []
    rows = []
    for name in sorted(os.listdir(user_dir)):
        has_extension = '.' in name
        is_hidden = name.startswith('.')
        if has_extension and not is_hidden:
            rows.append([os.path.join(user_dir, name)])
    return rows
# Showcase rows per task type: (template, source, result), relative to data_dir.
_SHOWCASE_EXAMPLES = {
    "3": (
        ("task3/temps/flow-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_flower-water.jpg"),
        ("task3/temps/mountain-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_mountain-water.jpg"),
        ("task3/temps/rock-on-water.jpg", "task3/srcs/src01.jpg", "task3/showcases/src01_rock-on-water.jpg"),
    ),
    "4": (
        ("task4/temps/Vivienne.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_vivienne.jpg"),
        ("task4/temps/Bella.jpg", "task4/srcs/src04.jpg", "task4/showcases/src04_balle.jpg"),
        ("task4/temps/Nia.jpg", "task4/srcs/src02.jpg", "task4/showcases/src02_nia.jpg"),
        ("task4/temps/Leo.jpg", "task4/srcs/src03.jpg", "task4/showcases/src03_male.jpg"),
    ),
    "6": (
        ("task6/temps/niantu.jpg", "task6/srcs/src01.jpg", "task6/showcases/src01_niantu.jpg"),
        ("task6/temps/3d-shouban.jpg", "task6/srcs/src02.jpg", "task6/showcases/src02_shouban.jpg"),
    ),
    "5": (
        ("task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_street.jpg"),
        ("task5/temps/caption.jpg", "task5/srcs/src01.jpg", "task5/showcases/src01_walk.jpg"),
    ),
    "1": (
        ("task1/temps/caption.jpg", "task1/srcs/src01.jpg", "task1/showcases/src01_seg.png"),
    ),
    "2": (
        ("task2/temps/caption.jpg", "task2/srcs/street.webp", "task2/showcases/out1.jpg"),
    ),
    # NOTE(review): task "7" reuses the task2 assets — confirm this is intended.
    "7": (
        ("task2/temps/caption.jpg", "task2/srcs/street.webp", "task2/showcases/out1.jpg"),
    ),
}


def get_showcase_examples(taskType):
    """Return absolute-path showcase rows for *taskType* (``[]`` if unknown).

    Each row is a fresh list ``[template, source, result]`` resolved against
    data_dir. Every path is asserted to exist, in row-major order.
    """
    examples = []
    for row in _SHOWCASE_EXAMPLES.get(taskType, ()):
        resolved = []
        for rel_path in row:
            full_path = os.path.join(data_dir, rel_path)
            assert os.path.exists(full_path), full_path
            resolved.append(full_path)
        examples.append(resolved)
    return examples
def get_result_example(cloth_id, pose_id):
    """Return the path ``<data_dir>/ResultImgs/<cloth_id>_<pose_id>.jpg``.

    The path is only built; whether the file exists is not checked.
    """
    file_name = f"{cloth_id}_{pose_id}.jpg"
    return os.path.join(data_dir, 'ResultImgs', file_name)
def _put_image(url, path):
    """PUT the file at *path* to *url* with a Content-Type derived from its extension.

    Returns True on HTTP 200. Otherwise it prints the response and returns False.
    """
    head_dict = {'jpg': 'image/jpeg', 'jpeg': 'image/jpeg', 'png': 'image/png'}
    ftype = os.path.basename(path).split(".")[-1].lower()
    headers = {"Content-Type": head_dict[ftype]}
    with open(path, 'rb') as file:
        response = requests.put(url, data=file, headers=headers)
    if response.status_code == 200:
        return True
    print(response)
    return False


def upload_user_img_mask(clientIp, img, mask=None):
    """Save *img* (and optional *mask*) as JPEGs in tmpFolder and upload them.

    Upload URLs (presumably pre-signed) are requested from ``{apiUrl}/upload``
    and the files are PUT to them.

    Args:
        clientIp: client address. Dots are stripped and the result becomes
            the file-name prefix.
        img: image array. Its last axis is reversed (presumably RGB -> BGR for
            cv2 — confirm) and it is cast to uint8 before writing.
        mask: optional array, written as-is with cv2.imwrite.

    Returns:
        ``(image_url, mask_url)``: the URL each file was successfully PUT to,
        or '' for an upload that failed or was not attempted.

    The temporary files are always removed, even when a call raises.
    """
    timeId = int(str(time.time()).replace(".", "")) + random.randint(1000, 9999)
    stem = clientIp.replace(".", "") + str(timeId)
    fileName = stem + ".jpg"
    filemName = stem + "_m.jpg"
    local_path = os.path.join(tmpFolder, fileName)
    localm_path = os.path.join(tmpFolder, filemName)
    res, uploadm_url = "", ""
    try:
        cv2.imwrite(local_path, img[:, :, ::-1].astype(np.uint8))
        if mask is not None:
            cv2.imwrite(localm_path, mask)
        params = {'token': uploadToken, 'input1': fileName, 'input2': filemName}
        ret = requests.post(f"{apiUrl}/upload", data=json.dumps(params))
        try:
            payload = ret.json()
        except ValueError:
            payload = None
        if ret.status_code != 200 or not isinstance(payload, dict) or 'upload1' not in payload:
            print(payload if payload is not None else ret.text, ret.status_code, 'call upload failed')
            return res, uploadm_url
        upload_url = payload['upload1']
        if _put_image(upload_url, local_path):
            res = upload_url
        if mask is not None:
            mask_url = payload.get('upload2')
            if not mask_url:
                print(payload, ret.status_code, 'call upload failed: no upload2 url')
            elif _put_image(mask_url, localm_path):
                uploadm_url = mask_url
    finally:
        for path in (local_path, localm_path):
            if os.path.exists(path):
                os.remove(path)
    return res, uploadm_url
def publicSelfitTask(image, mask, temp_image, caption_text, param4_text, param5_text):
    """Submit a processing task to ``{apiUrl}/public``.

    Args:
        image, mask: uploaded image / mask references (presumably the URLs
            returned by upload_user_img_mask — confirm against the caller).
        temp_image: template path. Its base name up to the first '.' is sent
            as ``param1``.
        caption_text: sent (stringified) as ``param2``.
        param4_text, param5_text: sent as ``param4`` / ``param5``.

    Returns:
        The task id from a successful response, otherwise None.
    """
    temp_name = os.path.basename(temp_image).split('.')[0]
    params = {'openId': openId, 'apiKey': apiKey, 'image': image, 'mask': mask,
              "image_type": "2", "task_type": taskType, 'param1': temp_name,
              'param2': str(caption_text), 'param3': "1",
              'param4': param4_text, 'param5': param5_text}
    ret = requests.post(f"{apiUrl}/public", data=json.dumps(params))
    print(ret)
    try:
        payload = ret.json()
    except ValueError:
        payload = None
    if ret.status_code == 200 and isinstance(payload, dict) and 'id' in payload:
        return payload['id']
    print(payload if payload is not None else ret.text, ret.status_code, 'call public failed')
    return None
def getTaskRes(taskId):
    """Query ``{apiUrl}/status`` for task *taskId*.

    Returns:
        The parsed JSON response when it is HTTP 200 and contains a
        ``status`` key, otherwise None.
    """
    params = {'id': taskId}
    ret = requests.post(f"{apiUrl}/status", data=json.dumps(params))
    try:
        payload = ret.json()
    except ValueError:
        payload = None
    if ret.status_code == 200:
        if isinstance(payload, dict) and 'status' in payload:
            return payload
        return None
    print(payload if payload is not None else ret.text, ret.status_code, 'call status failed')
    return None
@func_timeout.func_set_timeout(10)
def check_region(ip):
    """Return False if *ip* geolocates to a nation listed in ``Regions``, else True.

    Only the first entry of the lookup's ``data`` mapping is inspected. An
    empty lookup or an empty nation counts as allowed. Network/JSON errors
    propagate, and the decorator aborts the call after 10 s. Use
    check_region_warp for fail-open behaviour.
    """
    # params= URL-encodes the ip instead of splicing it into the query string.
    ret = requests.get("https://webapi-pc.meitu.com/common/ip_location", params={'ip': ip})
    data = ret.json()['data']
    for k in data:
        nat = data[k]['nation']
        # Guard empty nation: '' is a substring of every Regions string.
        # NOTE(review): `in` on the Regions string is a substring test
        # ('Niger' matches 'Nigeria'); split it if it is a delimited list.
        if nat and nat in Regions:
            print(nat, 'invalid')
            return False
        print(nat, 'valid')
        return True
    # No location entry at all: allow rather than return a falsy None.
    return True
def check_region_warp(ip):
    """Fail-open wrapper around check_region.

    Any lookup failure, including the func_timeout timeout, is printed and
    treated as allowed (True).
    """
    try:
        return check_region(ip)
    except (Exception, func_timeout.FunctionTimedOut) as e:
        # FunctionTimedOut derives from BaseException, so `except Exception`
        # alone would let timeouts escape instead of failing open.
        print(e)
        return True