Spaces:
Sleeping
Sleeping
File size: 4,757 Bytes
c2164fe 28f074f c2164fe 79df973 8b8b671 79df973 8b8b671 28f074f 5d31697 8b8b671 79df973 8b8b671 79df973 c2164fe 8b8b671 79df973 c2164fe 28f074f f40b380 28f074f f40b380 28f074f f40b380 8b8b671 28f074f f40b380 c2164fe 28f074f 79df973 8b8b671 79df973 8b8b671 28f074f 79df973 8b8b671 79df973 e60fd27 79df973 8b8b671 79df973 28f074f 79df973 c2164fe 5d31697 28f074f 79df973 28f074f e60fd27 79df973 c2164fe 8b8b671 c2164fe 79df973 8b8b671 c2164fe 8b8b671 3d1b36b 8b8b671 c2164fe 8b8b671 79df973 f8ec29a 79df973 8b8b671 79df973 8b8b671 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# -*- coding: utf-8 -*-
#
# @File: app.py
# @Author: Haozhe Xie
# @Date: 2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-09-22 10:31:28
# @Email: [email protected]
import gradio as gr
import logging
import numpy as np
import os
import spaces
import ssl
import subprocess
import sys
import torch
import urllib.request
from PIL import Image
# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed" by making the default HTTPS context factory return
# an *unverified* context for the whole process.
# NOTE(review): this disables TLS certificate checking for every HTTPS request
# that relies on the default context, including the checkpoint download in
# get_models(), whose file is later unpickled by torch.load. That combination
# is open to man-in-the-middle tampering; prefer fixing the CA bundle or
# passing an unverified context only to the specific request that needs it.
ssl._create_default_https_context = ssl._create_unverified_context
# Put the bundled "citydreamer" directory (next to this file) on sys.path,
# presumably so modules inside it can import each other by top-level name.
sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
def _get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    This is a best-effort diagnostic helper. Any failure (for example the
    executable is not found, the exit status is non-zero, or the output is
    not valid UTF-8) is logged together with its traceback, and ``None`` is
    returned instead of raising.

    Args:
        cmd: argv-style list passed straight to ``subprocess`` (no shell).
    """
    try:
        # Equivalent to subprocess.check_output(cmd): capture stdout only and
        # raise CalledProcessError on a non-zero exit status.
        completed = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
        raw_bytes = completed.stdout
        return raw_bytes.decode("utf-8")
    except Exception as err:
        logging.exception(err)
        return None
def setup_runtime_env():
    """Log toolchain versions and install the pre-built CUDA extension wheels.

    Every entry of the ``wheels`` directory next to this file is passed to
    ``pip install``. The install runs through the *running* interpreter
    (``sys.executable -m pip``), so the packages land in the same environment
    that will later import them. A bare ``pip`` found on PATH may belong to a
    different interpreter.

    Installation is best-effort. A failing ``pip install`` is logged as an
    error but does not abort start-up. The failure will presumably surface
    later, when the extensions are imported.

    Raises:
        FileNotFoundError: if the ``wheels`` directory does not exist.
    """
    logging.info("Python Version: %s", _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s", _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s", _get_output(["gcc", "--version"]))

    # Install pre-compiled CUDA extensions. Sort the entries so the install
    # order is deterministic (os.listdir() returns them in arbitrary order).
    ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    for e in sorted(os.listdir(ext_dir)):
        logging.info("Installing Extensions from %s", e)
        ret = subprocess.call(
            [sys.executable, "-m", "pip", "install", os.path.join(ext_dir, e)],
            stderr=subprocess.STDOUT,
        )
        if ret != 0:
            logging.error("Failed to install %s (pip exited with status %d)", e, ret)

    # Alternative (currently disabled): build the extensions from source by
    # running `pip install .` inside each sub-directory of
    # citydreamer/extensions instead of installing the pre-built wheels.
    logging.info(
        "Installed Python Packages: %s",
        _get_output([sys.executable, "-m", "pip", "list"]),
    )
def get_models(file_name):
    """Download (if needed) and load a pretrained CityDreamer generator.

    Args:
        file_name: checkpoint file name. It is looked up relative to the
            current working directory. If missing, it is downloaded from the
            ``hzxie/city-dreamer`` Hugging Face repository to that path.

    Returns:
        A ``citydreamer.model.GanCraftGenerator`` built from ``ckpt["cfg"]``
        and loaded with ``ckpt["gancraft_g"]``, switched to eval mode. When
        CUDA is available, it is wrapped in ``torch.nn.DataParallel`` and
        moved to the GPU.
    """
    # Deferred import: citydreamer only becomes importable after the start-up
    # setup (sys.path extension and extension install) has run.
    import citydreamer.model

    if not os.path.exists(file_name):
        # Download to a temporary name and rename it into place afterwards.
        # An interrupted download then never leaves a truncated checkpoint
        # that the exists() check above would accept on every later run.
        tmp_name = file_name + ".part"
        try:
            urllib.request.urlretrieve(
                "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
                tmp_name,
            )
            os.replace(tmp_name, file_name)
        finally:
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load unpickles the checkpoint, i.e. it can execute
    # code. That is only safe because the file comes from a trusted repository;
    # note that HTTPS verification is disabled process-wide in this app.
    ckpt = torch.load(file_name, map_location=torch.device(device))
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    # strict=False tolerates key mismatches, so report them instead of
    # silently running with partially loaded weights. NOTE(review): if the
    # checkpoint was saved from a DataParallel-wrapped model, its keys would
    # carry a "module." prefix that the unwrapped CPU model lacks. Confirm.
    incompatible = model.load_state_dict(ckpt["gancraft_g"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.warning(
            "%s: %d missing and %d unexpected keys while loading weights",
            file_name,
            len(incompatible.missing_keys),
            len(incompatible.unexpected_keys),
        )

    # Switch to inference mode on every device. Previously only the CUDA
    # branch called eval(), which left the CPU model in training mode.
    return model.eval()
def get_city_layout():
    """Load the New York City layout maps shipped in the ``assets`` directory.

    ``assets`` is resolved next to this file, as the other paths in the app
    are, so loading does not depend on the current working directory. Each
    image is closed as soon as its pixels have been copied into a NumPy array.

    Returns:
        ``(hf, seg)``: ``np.int32`` arrays decoded from
        ``assets/NYC-HghtFld.png`` (presumably a height field) and
        ``assets/NYC-SegMap.png`` (converted to palette mode ``"P"`` first,
        presumably a semantic segmentation map).
    """
    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    with Image.open(os.path.join(assets_dir, "NYC-HghtFld.png")) as img:
        hf = np.array(img).astype(np.int32)
    with Image.open(os.path.join(assets_dir, "NYC-SegMap.png")) as img:
        seg = np.array(img.convert("P")).astype(np.int32)
    return hf, seg
@spaces.GPU
def get_generated_city(
    radius, altitude, azimuth, map_center, progress=gr.Progress(track_tqdm=True)
):
    """Render one city view with the preloaded CityDreamer models.

    The function reads its shared state from attributes that the ``__main__``
    block attaches to it:

    * ``fgm`` / ``bgm``: the two generators (presumably foreground and
      background), moved to CUDA here.
    * ``hf`` / ``seg``: the layout arrays. They are copied, presumably so
      generation never modifies the shared originals.

    ``progress`` is not used directly. With ``track_tqdm=True``, Gradio
    presumably mirrors any tqdm progress bars shown during generation.

    Returns:
        Whatever ``citydreamer.inference.generate_city`` returns; the app
        displays it in a ``gr.Image(type="numpy")`` output.
    """
    cuda_ok = torch.cuda.is_available()
    logging.info("CUDA is available: %s", cuda_ok)
    logging.info("PyTorch is built with CUDA: %s", torch.version.cuda)

    # Deferred import: this module may only be imported once the CUDA
    # extensions are in place.
    import citydreamer.inference

    state = get_generated_city
    fg_model = state.fgm.to("cuda")
    bg_model = state.bgm.to("cuda")
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    # NOTE(review): the single slider value fills both centre arguments
    # (presumably the x and y map coordinates). Confirm against generate_city.
    return citydreamer.inference.generate_city(
        fg_model,
        bg_model,
        height_field,
        seg_map,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
def main(debug):
    """Build the Gradio interface around ``get_generated_city`` and launch it.

    The description is taken from ``README.md`` (the text after its last
    ``---``) and the article from ``ARTICLE.md``. Both files are resolved next
    to this file.

    Args:
        debug: forwarded unchanged to ``app.launch(debug=...)``.
    """
    title = "CityDreamer Demo 🏙️"
    base_dir = os.path.dirname(__file__)
    # Read the files as UTF-8 explicitly. The default encoding depends on the
    # locale, and these Markdown files may contain non-ASCII text (e.g. emoji).
    with open(os.path.join(base_dir, "README.md"), "r", encoding="utf-8") as f:
        markdown = f.read()
    # Keep the text after the last "---", presumably the end of the Hugging
    # Face YAML front matter. NOTE(review): a "---" rule in the README body
    # would also match. Without any "---", keep the whole text instead of
    # rfind()'s -1 + 3 silently dropping the first two characters.
    sep = markdown.rfind("---")
    desc = markdown if sep == -1 else markdown[sep + 3 :]
    with open(os.path.join(base_dir, "ARTICLE.md"), "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # setup_runtime_env() installs pre-built wheels; it does not compile
    # anything, so the log message says "installing".
    logging.info("Installing pre-compiled CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("CityDreamer-Fgnd.pth")
    bgm = get_models("CityDreamer-Bgnd.pth")
    # get_generated_city reads its models and layout from these function
    # attributes at request time.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf = hf
    get_generated_city.seg = seg

    logging.info("Starting the main application...")
    # Gradio debug mode is enabled only when the DEBUG env var is exactly "1".
    main(os.getenv("DEBUG") == "1")
|