File size: 4,757 Bytes
c2164fe
 
 
 
 
 
28f074f
c2164fe
 
79df973
8b8b671
79df973
8b8b671
28f074f
5d31697
8b8b671
79df973
8b8b671
79df973
 
 
 
c2164fe
8b8b671
 
 
79df973
c2164fe
 
28f074f
f40b380
28f074f
f40b380
 
 
28f074f
 
 
 
 
 
 
 
 
f40b380
8b8b671
28f074f
 
 
 
f40b380
 
 
 
 
c2164fe
28f074f
 
79df973
 
 
 
 
8b8b671
79df973
 
8b8b671
 
28f074f
 
79df973
8b8b671
79df973
 
e60fd27
79df973
8b8b671
79df973
 
28f074f
 
79df973
c2164fe
 
5d31697
 
 
 
28f074f
 
79df973
 
 
 
28f074f
 
e60fd27
 
 
 
79df973
 
 
 
c2164fe
 
8b8b671
 
c2164fe
 
79df973
8b8b671
 
c2164fe
8b8b671
 
 
3d1b36b
 
 
 
8b8b671
 
 
 
 
 
 
 
 
c2164fe
 
 
8b8b671
 
 
79df973
f8ec29a
79df973
8b8b671
79df973
 
 
 
 
 
 
 
 
 
8b8b671
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# -*- coding: utf-8 -*-
#
# @File:   app.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-09-22 10:31:28
# @Email:  [email protected]

import gradio as gr
import logging
import numpy as np
import os

import spaces
import ssl
import subprocess
import sys
import torch
import urllib.request

from PIL import Image

# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed" by replacing the stdlib's default HTTPS context
# factory (used by urllib.request when no context is given, e.g. in
# urlretrieve) with one that skips certificate verification.
# NOTE(review): this disables TLS certificate checks for the whole process,
# not only for the checkpoint download — consider scoping it more narrowly.
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)
# Make the bundled "citydreamer" directory (next to this file) importable so
# the CityDreamer modules can be loaded.
sys.path += [os.path.join(os.path.dirname(__file__), "citydreamer")]


def _get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Best-effort diagnostic helper. Any failure (missing executable, non-zero
    exit status, output that is not valid UTF-8, ...) is logged together with
    its traceback, and ``None`` is returned instead of raising.

    Args:
        cmd: Argument list passed straight to ``subprocess.check_output``.

    Returns:
        The decoded stdout text, or ``None`` on failure.
    """
    try:
        raw_bytes = subprocess.check_output(cmd)
        text = raw_bytes.decode("utf-8")
    except Exception as err:
        logging.exception(err)
        text = None
    return text


def setup_runtime_env():
    """Log toolchain versions and install the pre-built CUDA extension wheels.

    Every entry of the ``wheels`` directory next to this script is passed to
    ``pip install``. pip is invoked as ``<sys.executable> -m pip``, and the
    reported Python version is taken from ``sys.executable``. This makes the
    extensions land in the interpreter that is actually running this app, not
    in whichever ``python``/``pip`` happens to be first on ``PATH``.

    A missing ``wheels`` directory, or a failed ``pip install``, is logged as
    a warning rather than aborting start-up.
    """
    logging.info("Python Version: %s", _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s", _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s", _get_output(["gcc", "--version"]))

    # Install Pre-compiled CUDA extensions
    ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    if not os.path.isdir(ext_dir):
        logging.warning(
            "Extension directory %s not found; skipping installation.", ext_dir
        )
    else:
        # Sorted so the install order does not depend on the filesystem.
        for e in sorted(os.listdir(ext_dir)):
            logging.info("Installing Extensions from %s", e)
            ret = subprocess.call(
                [sys.executable, "-m", "pip", "install", os.path.join(ext_dir, e)],
                stderr=subprocess.STDOUT,
            )
            if ret != 0:
                logging.warning(
                    "pip exited with status %d while installing %s", ret, e
                )
    # Compile CUDA extensions (disabled: the pre-built wheels above are used)
    # ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
    # for e in os.listdir(ext_dir):
    #     if os.path.isdir(os.path.join(ext_dir, e)):
    #         subprocess.call(["pip", "install", "."], cwd=os.path.join(ext_dir, e))

    logging.info(
        "Installed Python Packages: %s",
        _get_output([sys.executable, "-m", "pip", "list"]),
    )


def _download_checkpoint(file_name):
    """Fetch ``file_name`` from the hzxie/city-dreamer Hugging Face repo if absent.

    The file is first downloaded to a ``<file_name>.part`` sibling and then
    atomically renamed. An interrupted download therefore never leaves a
    truncated checkpoint that a later run would mistake for a complete one.
    """
    if os.path.exists(file_name):
        return

    tmp_name = file_name + ".part"
    try:
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/city-dreamer/resolve/main/%s" % file_name,
            tmp_name,
        )
        os.replace(tmp_name, file_name)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of ``state_dict`` with any leading ``module.`` removed from keys."""
    prefix = "module."
    return {
        (key[len(prefix) :] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def get_models(file_name):
    """Download (if needed) and load a CityDreamer generator checkpoint.

    Args:
        file_name: Checkpoint file name. It is used both as the local path
            (relative to the working directory) and as the file name inside
            the Hugging Face repository.

    Returns:
        A ``GanCraftGenerator`` in eval mode with the ``gancraft_g`` weights
        loaded. On CUDA machines it is wrapped in ``torch.nn.DataParallel``
        and moved to the GPU; otherwise it stays on the CPU.
    """
    import citydreamer.model

    _download_checkpoint(file_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles arbitrary Python objects (the checkpoint also
    # stores a config object under "cfg"), so only load trusted checkpoints.
    ckpt = torch.load(file_name, map_location=torch.device(device))
    model = citydreamer.model.GanCraftGenerator(ckpt["cfg"])
    if device == "cuda":
        model = torch.nn.DataParallel(model).cuda()

    # Load into the unwrapped network with any "module." prefixes removed, so
    # the weights match whether or not the checkpoint was saved from a
    # DataParallel model. This matters because strict=False would otherwise
    # silently skip every mismatched key (e.g. on the CPU path).
    inner = model.module if isinstance(model, torch.nn.DataParallel) else model
    inner.load_state_dict(
        _strip_data_parallel_prefix(ckpt["gancraft_g"]), strict=False
    )
    # Inference-only: use eval mode on CPU as well as on GPU.
    return model.eval()


def get_city_layout(
    hf_path="assets/NYC-HghtFld.png", seg_path="assets/NYC-SegMap.png"
):
    """Load the New York city layout used as the generation canvas.

    Args:
        hf_path: Path to the height-field image.
        seg_path: Path to the segmentation-map image. It is converted to
            Pillow palette ("P") mode, i.e. one palette index per pixel,
            before conversion to NumPy.

    Returns:
        Tuple ``(hf, seg)`` of ``np.int32`` arrays.
    """
    # Context managers close the image files as soon as the pixel data has
    # been copied into NumPy, instead of leaving the handles open until the
    # garbage collector reclaims the Image objects.
    with Image.open(hf_path) as img:
        hf = np.array(img).astype(np.int32)
    with Image.open(seg_path) as img:
        seg = np.array(img.convert("P")).astype(np.int32)
    return hf, seg


@spaces.GPU
def get_generated_city(
    radius, altitude, azimuth, map_center, progress=gr.Progress(track_tqdm=True)
):
    """Gradio callback: render a CityDreamer view of the preloaded city layout.

    Besides its arguments, the function reads four attributes that the
    ``__main__`` block attaches to it before the UI starts:
    ``fgm``/``bgm`` (the two generator models) and ``hf``/``seg`` (the layout
    arrays).

    Args:
        radius: Camera radius (m), from the UI slider.
        altitude: Camera altitude (m).
        azimuth: Camera azimuth (degrees).
        map_center: Map centre (px). It is passed twice to ``generate_city``,
            presumably as both centre coordinates — TODO confirm.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars in the UI.

    Returns:
        Whatever ``citydreamer.inference.generate_city`` returns; the UI shows
        it as a NumPy image.
    """
    cuda_ok = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_ok)
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # Deferred import: citydreamer.inference needs the CUDA extensions, which
    # are only installed once start-up has run.
    import citydreamer.inference

    render = citydreamer.inference.generate_city
    state = get_generated_city
    # Move both generators to the GPU and give the renderer copies of the
    # layout arrays so the cached originals stay untouched between requests.
    foreground = state.fgm.to("cuda")
    background = state.bgm.to("cuda")
    height_field = state.hf.copy()
    seg_map = state.seg.copy()
    return render(
        foreground,
        background,
        height_field,
        seg_map,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )


def _read_text(path):
    """Return the contents of the text file at ``path``, decoded as UTF-8.

    The encoding is explicit because the Markdown files may contain non-ASCII
    characters (e.g. emoji), and the platform's default encoding is not
    guaranteed to be UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def _strip_front_matter(markdown):
    """Return the part of ``markdown`` after its last ``---`` delimiter.

    Used to drop the front-matter block at the top of README.md. If no
    delimiter is present, the whole text is returned. Blindly slicing from
    ``rfind(...) + 3`` would instead silently drop the first two characters.
    """
    idx = markdown.rfind("---")
    return markdown if idx < 0 else markdown[idx + 3 :]


def main(debug):
    """Build the CityDreamer Gradio interface and launch it.

    Args:
        debug: Forwarded to ``launch(debug=...)`` of the Gradio app.
    """
    title = "CityDreamer Demo 🏙️"
    desc = _strip_front_matter(_read_text("README.md"))
    arti = _read_text("ARTICLE.md")

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(128, 512, value=343, step=5, label="Camera Radius (m)"),
            gr.Slider(256, 512, value=296, step=5, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1440, 6752, value=3970, step=5, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        allow_flagging="never",
    )
    # Queue requests; api_open=False keeps clients from bypassing the queue
    # through the REST API (see Gradio's Blocks.queue documentation).
    app.queue(api_open=False)
    app.launch(debug=debug)


if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # setup_runtime_env() installs pre-built extension wheels (its source
    # compilation path is commented out), so the log must not say "Compiling".
    logging.info("Installing CUDA extensions...")
    setup_runtime_env()

    # The generators and the layout arrays are attached as attributes of the
    # Gradio callback so every request reuses them instead of reloading.
    logging.info("Downloading pretrained models...")
    fgm = get_models("CityDreamer-Fgnd.pth")
    bgm = get_models("CityDreamer-Bgnd.pth")
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    hf, seg = get_city_layout()
    get_generated_city.hf = hf
    get_generated_city.seg = seg

    logging.info("Starting the main application...")
    # DEBUG=1 in the environment enables Gradio's debug mode.
    main(os.getenv("DEBUG") == "1")