# model-swap-ai / app.py
# Author: kailashahirwar
# First commit: code for the demo app (475fa42).
import os.path
import gradio as gr
import json
import requests
import time
from gradio_modal import Modal
from io import BytesIO
# Location of the TryOn server. The port is appended to the host only when it
# differs from "80"; every API route lives under /api/v1/.
TRYON_SERVER_HOST = "https://prod.server.tryonlabs.ai"
TRYON_SERVER_PORT = "80"

TRYON_SERVER_URL = (
    TRYON_SERVER_HOST
    if TRYON_SERVER_PORT == "80"
    else f"{TRYON_SERVER_HOST}:{TRYON_SERVER_PORT}"
)
TRYON_SERVER_API_URL = f"{TRYON_SERVER_URL}/api/v1/"
def start_model_swap(input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps):
    """Run a model-swap experiment on the TryOn server and return its result images.

    The flow is: upload ``input_image`` as an experiment image, create a
    ``model_swap`` experiment that references it, then poll the experiment
    every 10 seconds until it leaves the ``running`` state.

    Args:
        input_image: Image to process; it must provide ``save(fileobj, format)``
            (the UI supplies a PIL image).
        prompt: Text prompt; must be non-empty.
        cls: Garment type to retain; sent to the server as ``"<cls> garment"``.
        seed: Forwarded unchanged in the experiment ``params``.
        guidance_scale: Forwarded unchanged in the experiment ``params``.
        num_results: Forwarded unchanged in the experiment ``params``.
        strength: Forwarded unchanged in the experiment ``params``.
        inference_steps: Sent as ``num_inference_steps``.

    Returns:
        The ``result_images`` list reported by ``fetch_experiment_status``.

    Raises:
        gr.Error: If an input is missing, no token is stored, a request fails
            or is rejected, or the experiment fails, times out, or reports an
            unexpected status.
    """
    print("inputs:", input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps)

    if input_image is None:
        raise gr.Error("Select an image!")
    if not prompt:
        raise gr.Error("Enter a prompt!")

    token = load_token()
    if not token:
        raise gr.Error("You need to login first!")
    login(token)

    auth_headers = {"Authorization": f"Bearer {token}"}
    # Without a timeout a stalled connection would block this handler forever.
    request_timeout_seconds = 60

    # 1. create an experiment image
    byte_io = BytesIO()
    input_image.save(byte_io, 'png')
    byte_io.seek(0)
    try:
        r = requests.post(f"{TRYON_SERVER_API_URL}experiment_image/",
                          files={"image": ('ei_image.png', byte_io, 'image/png')},
                          data={"type": "model", "preprocess": "false"},
                          headers=auth_headers,
                          timeout=request_timeout_seconds)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        raise gr.Error(f"Failure: {exc}") from exc
    if r.status_code not in (200, 201):
        print(f"Error: {r.text}")
        raise gr.Error(f"Failure: {r.text}")
    res = r.json()
    print("Experiment image created successfully", res)

    # 2. create an experiment
    params = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "strength": strength,
        "num_inference_steps": inference_steps,
        "seed": seed,
        "garment_class": f"{cls} garment",
        "negative_prompt": "(hands:1.15), disfigured, ugly, bad, immature"
                           ", cartoon, anime, 3d, painting, b&w, (ugly),"
                           " (pixelated), watermark, glossy, smooth, "
                           "earrings, necklace",
        "num_results": num_results,
    }
    try:
        r2 = requests.post(f"{TRYON_SERVER_API_URL}experiment/",
                           data={
                               "model_id": res['id'],
                               "action": "model_swap",
                               "params": json.dumps(params),
                           },
                           headers=auth_headers,
                           timeout=request_timeout_seconds)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        raise gr.Error(f"Failure: {exc}") from exc
    if r2.status_code not in (200, 201):
        print(f"Error: {r2.text}")
        raise gr.Error(f"Failure: {r2.text}")
    res2 = r2.json()
    print("Experiment created successfully", res2)

    # 3. keep checking the status of the experiment
    experiment_id = res2['experiment']['id']
    poll_interval_seconds = 10
    max_polls = 90  # give up after roughly 15 minutes instead of looping forever
    status, state = {}, "running"
    for _ in range(max_polls):
        time.sleep(poll_interval_seconds)
        # Treat a missing status payload as an empty dict so it is reported
        # below as an unexpected status rather than raising TypeError here.
        status = fetch_experiment_status(experiment_id=experiment_id, token=token) or {}
        state = status.get('status')
        print(f"Current status: {state}")
        if state != "running":
            break
    else:
        print("Experiment timed out")
        raise gr.Error("Experiment timed out")

    if state == "success":
        print("Experiment successful")
        print(f"Results:{status['result_images']}")
        return status['result_images']
    if state == "failed":
        print("Experiment failed")
        raise gr.Error("Experiment failed")

    # Any other status used to fall through and return None silently.
    print(f"Error: unexpected experiment status {state!r}")
    raise gr.Error(f"Failure: unexpected experiment status {state!r}")
def fetch_experiment_status(experiment_id, token):
    """Fetch the current state of an experiment from the TryOn server.

    Args:
        experiment_id: Identifier of the experiment to query.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        A dict that always contains ``"status"``: ``"running"``, ``"success"``
        or ``"failed"``. On success it also contains ``"result_images"``, a
        list of absolute image URLs (the primary result first, followed by
        every entry of the experiment's ``results``). A request error, a
        non-200 response, or an unrecognised status is reported as
        ``"failed"``.
    """
    print(f"experiment id:{experiment_id}")
    try:
        r3 = requests.get(f"{TRYON_SERVER_API_URL}experiment/{experiment_id}/",
                          headers={"Authorization": f"Bearer {token}"},
                          timeout=60)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return {"status": "failed"}

    if r3.status_code != 200:
        print(f"Error: {r3.text}")
        return {"status": "failed"}

    res = r3.json()
    status = res.get('status')
    if status == "running":
        return {"status": "running"}
    if status == "success":
        experiment = res['experiment']
        result_images = [f"{TRYON_SERVER_URL}/{experiment['result']['image_url']}"]
        result_images.extend(
            f"{TRYON_SERVER_URL}/{result['image_url']}" for result in experiment['results']
        )
        return {"status": "success", "result_images": result_images}

    # Callers index the returned dict, so never fall through and return None.
    if status != "failed":
        print(f"Error: unexpected experiment status {status!r}")
    return {"status": "failed"}
def get_user_credits(token):
    """Return the ``credits`` value of the user that owns ``token``.

    Args:
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The ``credits`` field of the ``user/get/`` response, or ``None`` when
        the token is empty or ``None``, the request fails, the server does not
        answer with HTTP 200, or the field is absent.
    """
    # load_token() reports a missing token as None, not only as "".
    if not token:
        return None

    try:
        r = requests.get(f"{TRYON_SERVER_API_URL}user/get/",
                         headers={"Authorization": f"Bearer {token}"},
                         timeout=30)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return None

    if r.status_code != 200:
        print(f"Error: {r.text}")
        return None
    return r.json().get('credits')
def load_token():
    """Read the stored access token from the ``.token`` file.

    Returns:
        The value stored under the ``"token"`` key, or ``None`` when the file
        is missing, unreadable, not valid JSON, not a JSON object, or has no
        ``"token"`` key.
    """
    try:
        with open(".token", "r") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: file missing or unreadable; ValueError: malformed JSON.
        # Both simply mean "no usable token".
        return None

    if not isinstance(data, dict):
        return None
    return data.get("token")
def save_token(access_token):
    """Persist ``access_token`` to the ``.token`` file as ``{"token": ...}``.

    Args:
        access_token: Token to store; must be non-empty.

    Raises:
        gr.Error: If ``access_token`` is empty or ``None``. Nothing is written
            in that case.
    """
    # Reject None as well as "": otherwise None would be persisted as null.
    if not access_token:
        raise gr.Error("No token provided!")

    with open(".token", "w") as f:
        json.dump({"token": access_token}, f)
def is_logged_in():
    """Return True when ``load_token()`` yields a token that is neither None nor empty."""
    stored_token = load_token()
    return not (stored_token is None or stored_token == "")
def login(token):
    """Verify ``token`` with the TryOn server and store it on success.

    Args:
        token: Access token to verify via ``/api/token/verify/``.

    Returns:
        ``True`` once the token has been verified and saved.

    Raises:
        gr.Error: If the request fails or the server does not answer with
            HTTP 200.
    """
    print("logging in...")
    # validate token
    try:
        r = requests.post(f"{TRYON_SERVER_URL}/api/token/verify/",
                          data={"token": token},
                          timeout=30)
    except requests.RequestException as exc:
        # Report connection problems the same way as a rejected token
        # instead of letting a raw requests exception escape.
        print(f"Error: {exc}")
        raise gr.Error("Login failed") from exc

    if r.status_code != 200:
        raise gr.Error("Login failed")

    save_token(token)
    return True
def logout():
    """Blank the stored token and return the reset ``[logged_in, user_token]`` values."""
    print("logged out")
    cleared_payload = json.dumps({"token": ""})
    with open(".token", "w") as token_file:
        token_file.write(cleared_payload)
    return [False, ""]
# Custom stylesheet passed to gr.Blocks(css=...). The selectors match the
# elem_id values used in the layout: "col-container", "credits-col-container",
# "login-modal" and "login-logout-btn".
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#credits-col-container{
display:flex;
justify-content: right;
align-items: center;
font-size: 24px;
margin-right: 1rem;
}
#login-modal{
max-width: 728px;
margin: 0 auto;
margin-top: 1rem;
margin-bottom: 1rem;
}
#login-logout-btn{
display:inline;
max-width: 124px;
}
"""
# Gradio UI: a login modal, a header row with the login/credits area, the
# model-swap form, and the event wiring between them.
# NOTE(review): the token is read from and written to a single ".token" file
# in the working directory, so it looks shared by every visitor of the app -
# confirm this is intended (e.g. single-user deployments only).
with gr.Blocks(css=css) as demo:
    initially_logged_in = is_logged_in()
    print("is logged in:", initially_logged_in)
    logged_in = gr.State(initially_logged_in)
    # Reuse load_token() rather than parsing the token file a second time here.
    user_token = gr.State(load_token() or "")

    # Login dialog; re-rendered whenever the stored token changes so the
    # textbox is pre-filled with the current value.
    with Modal(visible=False) as modal:
        @gr.render(inputs=user_token)
        def rerender1(user_token1):
            with gr.Column(elem_id="login-modal"):
                access_token = gr.Textbox(
                    label="Token",
                    lines=1,
                    value=user_token1,
                    type="password",
                    placeholder="Enter your access token here!",
                    info="Visit https://playground.tryonlabs.ai to retrieve your access token."
                )
                login_submit_btn = gr.Button("Login", scale=1, variant='primary')
                # On success: mark the session logged in, close the modal and
                # remember the token. login() raises gr.Error on failure.
                login_submit_btn.click(
                    fn=lambda entered_token: (login(entered_token), Modal(visible=False), entered_token),
                    inputs=[access_token], outputs=[logged_in, modal, user_token],
                    concurrency_limit=1)

    with gr.Row(elem_id="col-container"):
        with gr.Column():
            gr.Markdown("""
# Model Swap AI
## by TryOn Labs (https://www.tryonlabs.ai)
Swap a human model with a artificial model generated by Artificial Model while keeping the garment intact.
""")

        # Shows a Login button when logged out, otherwise the credit balance
        # and a Logout button. The parameter is deliberately not named
        # is_logged_in so it does not shadow the module-level function.
        @gr.render(inputs=logged_in)
        def rerender(user_is_logged_in):
            with gr.Column():
                if not user_is_logged_in:
                    with gr.Row(elem_id="credits-col-container"):
                        login_btn = gr.Button(value="Login", variant='primary', elem_id="login-logout-btn", size="sm")
                        login_btn.click(lambda: Modal(visible=True), None, modal)
                else:
                    user_credits = get_user_credits(load_token())
                    print("user_credits", user_credits)
                    gr.HTML(f"""<div><p id="credits-col-container">Your Credits:
{user_credits if user_credits is not None else "0"}</p>
<p style="text-align: right;">Visit <a href="https://playground.tryonlabs.ai">
TryOn AI Playground</a> to acquire more credits</p></div>""")
                    with gr.Row(elem_id="credits-col-container"):
                        logout_btn = gr.Button(value="Logout", scale=1, variant='primary', size="sm",
                                               elem_id="login-logout-btn")
                        logout_btn.click(fn=logout, inputs=None, outputs=[logged_in, user_token], concurrency_limit=1)

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Original image", type='pil', height="400px", show_label=True)
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Enter your prompt here!",
                )
                dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Retain garment",
                                       info="Select the garment type you want to retain in the generated image!")
            gallery = gr.Gallery(
                label="Generated images", show_label=True, elem_id="gallery",
                columns=[3], rows=[1], object_fit="contain", height="auto")

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Number(label="Seed", value=-1, interactive=True, minimum=-1)
                guidance_scale = gr.Number(label="Guidance Scale", value=7.5, interactive=True, minimum=0.0,
                                           maximum=10.0,
                                           step=0.1)
                num_results = gr.Number(label="Number of results", value=2, minimum=1, maximum=5)
            with gr.Row():
                strength = gr.Slider(0.00, 1.00, value=0.99, label="Strength",
                                     info="Choose between 0.00 and 1.00", step=0.01, interactive=True)
                inference_steps = gr.Number(label="Inference Steps", value=20, interactive=True, minimum=1, step=1)

        with gr.Row():
            submit_button = gr.Button("Submit", variant='primary', scale=1)
            reset_button = gr.ClearButton(value="Reset", scale=1)

    gr.on(
        triggers=[submit_button.click],
        fn=start_model_swap,
        inputs=[input_image, prompt, dropdown, seed, guidance_scale, num_results, strength, inference_steps],
        outputs=[gallery]
    )

    # Reset restores every control to its initial value; the tuple order must
    # match the outputs list below.
    reset_button.click(
        fn=lambda: (None, None, "upper", None, -1, 7.5, 2, 0.99, 20),
        inputs=[],
        outputs=[input_image, prompt, dropdown, gallery, seed, guidance_scale,
                 num_results, strength, inference_steps],
        concurrency_limit=1,
    )

if __name__ == '__main__':
    demo.launch()