# LOHJC - added initial files (77a9008)
# (file-viewer residue: raw | history blame | 2.33 kB)
import onnx
import onnxruntime as ort
import numpy as np
# Location of the exported generator. The path is relative to the current
# working directory, so the app must be started from the folder holding the model.
MODEL_PATH = r"./"
model_name = "animalImageGAN_full.onnx"
ONNX_MODEL_PATH = MODEL_PATH + model_name

# Validate the ONNX graph once at start-up so a broken file fails immediately.
onnx_model = onnx.load(ONNX_MODEL_PATH)
onnx.checker.check_model(onnx_model)

# Shared random generator; generateImage() draws its latent vectors from it.
rng = np.random.default_rng()

# Target sample statistics for each latent vector fed to the generator.
desired_mean = 0
desired_variance = 1
# Number of latent values per sample; generateImage() builds a
# (generator_input_size, 1, 1) input tensor from it.
generator_input_size = 50

# NOTE(review): not referenced anywhere in this file (generateImage() samples
# from `rng` instead); kept only in case other code imports it.
latent_space_samples = np.random.rand(generator_input_size, 1, 1).astype(np.float32)

# Pin the execution provider explicitly: since ONNX Runtime 1.9, builds that
# ship more than one provider (e.g. onnxruntime-gpu) raise ValueError when
# `providers` is omitted. CPU is what a CPU-only build would use anyway.
ort_sess = ort.InferenceSession(ONNX_MODEL_PATH, providers=["CPUExecutionProvider"])
import gradio as gr
def generateImage():
    """Sample one latent vector, run the ONNX generator and return an image.

    The latent vector has shape ``(generator_input_size, 1, 1)``. It is drawn
    uniformly on [0, 1) from ``rng``, then rescaled so that its sample mean and
    variance equal ``desired_mean`` and ``desired_variance``.

    Returns:
        numpy.ndarray: the first model output after ``x * 0.5 + 0.5`` and
        clipping to [0, 1], with axis 0 moved to the end via
        ``transpose(1, 2, 0)``. This presumes the model emits a (C, H, W)
        tensor in roughly [-1, 1] (e.g. tanh), so the result is an (H, W, C)
        image ready for ``gr.Image``.
    """
    latent = rng.random((generator_input_size, 1, 1), dtype=np.float32)

    # Standardise to zero mean / unit variance, then rescale to the target stats.
    # NOTE(review): only the moments change - the values stay uniformly
    # distributed (on about [-sqrt(3), sqrt(3)]), not Gaussian. If the generator
    # was trained on N(0, 1) noise, rng.standard_normal may match better; confirm.
    latent = (latent - latent.mean()) / np.sqrt(latent.var())
    latent = latent * np.sqrt(desired_variance) + desired_mean

    # np.sqrt(desired_variance) is a float64 NumPy scalar. Under NumPy >= 2
    # (NEP 50) it promotes the float32 array to float64, which ONNX Runtime
    # rejects if the model input is declared float32 (as the dtype above
    # implies). Cast back explicitly.
    latent = latent.astype(np.float32, copy=False)

    output = ort_sess.run(None, {'input': latent})[0]

    # Shift/scale into [0, 1] for display and clamp any overshoot.
    image = np.clip(output * 0.5 + 0.5, 0, 1)
    return image.transpose(1, 2, 0)
# HTML header for the demo page: title, one-line summary, and links to the
# write-up article and the training notebook. Fragments are joined in order.
DESCRIPTION = "".join(
    (
        "<div style='text-align:center'><h1 style='justify-content: center'>Animal Portrait Generator</h1>",
        "<p>This is a model trained by using DCGAN</p>",
        "<p>More details:</p>",
        "<ul><li><a href='https://medium.com/@jiachiewloh/dcgan-animal-image-generator-85e466fb6254'>Article</a></li>",
        "<li><a href='https://www.kaggle.com/code/jclohjc/animal-image-generator-dcgan'>Code</a></li></ul>",
        "</div>",
    )
)
# Page layout: description header, one output image, and a row with
# "Generate" / clear buttons.
# `!important` must follow a declaration's value. The original trailing
# `width: 250px; !important` was an invalid declaration that browsers drop, so
# the intended 250x250 image size was never actually forced.
with gr.Blocks(
    css=(
        "#img_window {text-align:center; justify-content: center;}"
        ".image-container {margin: auto; height: 250px !important; width: 250px !important;}"
    )
) as demo:
    gr.Markdown(DESCRIPTION)
    img_window = gr.Image(interactive=False, elem_id="img_window")
    with gr.Row():
        # Each click calls generateImage() and shows its return value in img_window.
        gr.Button("Generate").click(fn=generateImage, outputs=img_window)
        # Clear button wired to img_window.
        gr.ClearButton().add(img_window)

demo.launch()