Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# | |
# @File: app.py | |
# @Author: Haozhe Xie | |
# @Date: 2024-03-02 16:30:00 | |
# @Last Modified by: Haozhe Xie | |
# @Last Modified at: 2024-09-22 10:31:28 | |
# @Email: [email protected] | |
import gradio as gr | |
import logging | |
import numpy as np | |
import os | |
import spaces | |
import ssl | |
import subprocess | |
import sys | |
import torch | |
import urllib.request | |
from PIL import Image | |
# Import CityDreamer modules: make the modules inside the bundled
# "citydreamer" directory importable by their top-level names (presumably
# because CityDreamer's own code imports its submodules that way - confirm).
sys.path.append(
    os.path.join(os.path.dirname(__file__), "citydreamer")
)

# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed - every default HTTPS context created from here
# on skips certificate verification.
# NOTE(review): this disables TLS verification process-wide, including the
# checkpoint downloads in get_models(); consider scoping it more narrowly.
ssl._create_default_https_context = ssl._create_unverified_context
def _get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Best-effort helper used for diagnostic logging. Any failure is logged with
    its traceback, and ``None`` is returned instead of raising. Failures include
    a missing executable, a non-zero exit status, or output that is not valid
    UTF-8.
    """
    try:
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
        text = completed.stdout.decode("utf-8")
    except Exception as ex:
        logging.exception(ex)
        return None
    return text
def setup_runtime_env():
    """Log toolchain versions and install the pre-compiled CUDA extension wheels.

    Every entry of the ``wheels`` directory next to this file is passed to
    ``pip install``, in sorted order.

    pip (and the version probe) are invoked through ``sys.executable``. This
    installs packages into the interpreter running this app, not into
    whichever ``pip``/``python`` happens to be first on ``PATH``.

    Installation stays best-effort: a failed ``pip install`` is logged as an
    error but does not abort start-up, and a missing ``wheels`` directory is
    logged and skipped.
    """
    logging.info("Python Version: %s", _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s", _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s", _get_output(["gcc", "--version"]))

    # Install pre-compiled CUDA extensions.
    ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    try:
        # Sorted so the install order is deterministic across file systems.
        wheels = sorted(os.listdir(ext_dir))
    except FileNotFoundError:
        logging.warning("Extension directory %s does not exist; skipping.", ext_dir)
        wheels = []
    for e in wheels:
        logging.info("Installing Extensions from %s", e)
        ret = subprocess.call(
            [sys.executable, "-m", "pip", "install", os.path.join(ext_dir, e)],
            stderr=subprocess.STDOUT,
        )
        if ret != 0:
            logging.error("Failed to install %s (pip exit status %d)", e, ret)
    # Alternative (previously kept here as commented-out code): compile the
    # extensions from source by running `pip install .` inside each
    # sub-directory of citydreamer/extensions.
    logging.info(
        "Installed Python Packages: %s",
        _get_output([sys.executable, "-m", "pip", "list"]),
    )
def get_models(file_name):
    """Load a CityDreamer generator checkpoint, downloading it first if needed.

    Args:
        file_name: Local path of the checkpoint. When the file does not exist,
            it is fetched from
            ``https://huggingface.co/hzxie/city-dreamer/resolve/main/<file_name>``.

    Returns:
        A ``citydreamer.model.GanCraftGenerator`` built from ``ckpt["cfg"]``.
        It is wrapped in ``torch.nn.DataParallel``, set to eval mode, and
        placed on CUDA when available (otherwise on the CPU). Weights come from
        ``ckpt["gancraft_g"]``.
    """
    # Imported lazily; the citydreamer package lives next to this file.
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download to a side file and rename it into place only once complete.
        # This prevents an interrupted download from leaving a truncated
        # checkpoint that the exists() check would accept on the next start.
        tmp_name = file_name + ".part"
        try:
            urllib.request.urlretrieve(
                "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
                tmp_name,
            )
            os.replace(tmp_name, file_name)
        finally:
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load unpickles the file. ckpt["cfg"] may be a
    # non-tensor object, so weights_only=True is not enabled here - confirm
    # before tightening this.
    ckpt = torch.load(file_name, map_location=torch.device(device))
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    # Wrap in DataParallel on every device so parameter names are the same on
    # CPU and GPU. With strict=False, a key-prefix mismatch would silently load
    # no weights. The checkpoint keys presumably carry DataParallel's "module."
    # prefix, since the CUDA path loads into a wrapped model - confirm.
    # Without GPUs, DataParallel simply forwards calls to the wrapped module.
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()
    # Inference only: eval mode on every device, not just when CUDA is present.
    model = model.eval()
    model.load_state_dict(ckpt["gancraft_g"], strict=False)
    return model
def get_city_layout(
    hf_path="assets/NYC-HghtFld.png", seg_path="assets/NYC-SegMap.png"
):
    """Load the New York city layout used by the demo.

    Args:
        hf_path: Image holding the height field. Relative paths are resolved
            against the current working directory.
        seg_path: Image holding the semantic map. It is converted to palette
            ("P") mode before being turned into an array.

    Returns:
        ``(hf, seg)`` as ``np.int32`` numpy arrays.
    """
    # Use the images as context managers so their file handles are closed
    # once the pixels have been copied into numpy arrays.
    with Image.open(hf_path) as img:
        hf = np.array(img).astype(np.int32)
    with Image.open(seg_path) as img:
        seg = np.array(img.convert("P")).astype(np.int32)
    return hf, seg
def get_generated_city(
    radius, altitude, azimuth, map_center, progress=gr.Progress(track_tqdm=True)
):
    """Gradio handler: render a city view with the preloaded CityDreamer models.

    The models (``fgm``/``bgm``) and the layout arrays (``hf``/``seg``) are read
    from attributes stored on this function by the ``__main__`` section. The
    layout arrays are copied, so the cached originals are not passed on.
    ``progress`` is not used in the body; Gradio supplies it and, with
    ``track_tqdm=True``, tracks tqdm progress bars.
    """
    logging.info("CUDA is available: %s", torch.cuda.is_available())
    logging.info("PyTorch is built with CUDA: %s", torch.version.cuda)

    # The import must be done after the CUDA extensions have been installed.
    import citydreamer.inference

    state = get_generated_city
    fg_model = state.fgm.to("cuda")
    bg_model = state.bgm.to("cuda")
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    # NOTE(review): map_center is passed twice, presumably as both centre
    # coordinates - confirm against generate_city's signature.
    # NOTE(review): `spaces` is imported at file top but unused; on ZeroGPU
    # hardware this handler may need the @spaces.GPU decorator - confirm.
    return citydreamer.inference.generate_city(
        fg_model,
        bg_model,
        height_field,
        seg_map,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
def main(debug):
    """Build and launch the CityDreamer Gradio demo.

    The description is the text of README.md after its last ``---`` marker,
    presumably the end of the Space's YAML front matter. ARTICLE.md is passed
    verbatim as the interface's article.

    Args:
        debug: Forwarded to ``app.launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) may fail on emoji or other non-ASCII text in the docs.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # Keep everything after the last "---". Without a marker, use the whole
    # README: rfind() returns -1, and -1 + 3 would drop the first two chars.
    sep = markdown.rfind("---")
    desc = markdown[sep + 3 :] if sep >= 0 else markdown
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    # api_open=False: per Gradio's queue() docs, this keeps the backend's REST
    # routes closed, so requests cannot bypass the queue.
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="[%(levelname)s] %(asctime)s %(message)s"
    )

    # 1. Install the CUDA extensions (setup_runtime_env installs the wheels).
    logging.info("Compiling CUDA extensions...")
    setup_runtime_env()

    # 2. Load both generators and attach them to the Gradio handler, which
    #    reads them back as function attributes.
    logging.info("Downloading pretrained models...")
    fgm, bgm = (
        get_models("CityDreamer-Fgnd.pth"),
        get_models("CityDreamer-Bgnd.pth"),
    )
    for attr, value in (("fgm", fgm), ("bgm", bgm)):
        setattr(get_generated_city, attr, value)

    # 3. Keep the city layout in memory for every request.
    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf, get_generated_city.seg = hf, seg

    logging.info("Starting the main application...")
    main(debug=os.getenv("DEBUG") == "1")