hzxie commited on
Commit
8b8b671
β€’
1 Parent(s): 0222435

Switch to Gradio.

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. ARTICLE.md +20 -0
  3. README.md +4 -4
  4. app.py +78 -30
  5. assets/style.css +7 -0
  6. requirements.txt +7 -2
.gitignore CHANGED
@@ -179,3 +179,5 @@ configs/
179
  data/
180
  notebooks/
181
  output/
 
 
 
179
  data/
180
  notebooks/
181
  output/
182
+ flagged/
183
+ *.pth
ARTICLE.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##### Citation πŸ“
2
+
3
+ If our work is useful for your research, please consider citing:
4
+
5
+ ```bibtex
6
+ @inproceedings{xie2024citydreamer,
7
+ title = {City{D}reamer: Compositional Generative Model of Unbounded 3{D} Cities},
8
+ author = {Xie, Haozhe and
9
+ Chen, Zhaoxi and
10
+ Hong, Fangzhou and
11
+ Liu, Ziwei},
12
+ booktitle = {CVPR},
13
+ year = {2024}
14
+ }
15
+ ```
16
+
17
+ ##### License πŸ“‹
18
+
19
+ This project is licensed under [S-Lab License 1.0](https://huggingface.co/hzxie/city-dreamer/blob/main/LICENSE).
20
+ Redistribution and use for non-commercial purposes should follow this license.
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: CityDreamer
3
  emoji: πŸ™οΈ
4
- colorFrom: pink
5
- colorTo: indigo
6
- sdk: streamlit
7
- sdk_version: 1.31.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
1
  ---
2
  title: CityDreamer
3
  emoji: πŸ™οΈ
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.41.0
8
  app_file: app.py
9
  pinned: false
10
  license: other
app.py CHANGED
@@ -4,54 +4,102 @@
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
- # @Last Modified at: 2024-03-02 22:33:17
8
  # @Email: [email protected]
9
 
10
- import streamlit as st
 
 
 
 
 
 
 
11
 
12
- from PIL import Image
 
13
 
14
- LOGGER = st.logger.get_logger(__name__)
 
 
 
15
 
16
 
17
  def setup_runtime_env():
18
- pass
 
 
 
 
 
19
 
20
 
21
  def get_models():
22
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  def get_generated_city(radius, altitude, azimuth):
26
- pass
27
 
28
 
29
- def main(fgm, bgm):
30
- st.set_page_config(
31
- page_title="CityDreamer Demo",
32
- page_icon="πŸ™οΈ",
33
- )
34
- # Main
35
- st.write("# CityDreamer Minimal Demo πŸ™οΈ")
36
  with open("README.md", "r") as f:
37
  markdown = f.read()
38
- st.markdown(markdown[markdown.rfind("---") :])
39
- imgbox = st.empty()
 
 
 
40
 
41
- # Sidebar
42
- st.sidebar.header("CityDreamer Settings")
43
- radius = st.sidebar.slider("Camera Radius (m)", 128, 512, 320, 5)
44
- altitude = st.sidebar.slider("Camera Altitude (m)", 256, 512, 384, 5)
45
- azimuth = st.sidebar.slider("Camera Azimuth (Β°)", 0, 360, 180, 5)
46
- if st.sidebar.button("Generate", type="primary"):
47
- img = get_generated_city(radius, altitude, azimuth)
48
- imgbox.image(img, caption="CityDreamer Generation")
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  if __name__ == "__main__":
52
- LOGGER.info("Setting up runtime environment...")
53
- setup_runtime_env()
54
- fgm, bgm = get_models()
55
-
56
- LOGGER.info("Starting the main application...")
57
- main(fgm, bgm)
 
 
 
 
4
  # @Author: Haozhe Xie
5
  # @Date: 2024-03-02 16:30:00
6
  # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-03-03 10:25:43
8
  # @Email: [email protected]
9
 
10
+ import logging
11
+ import os
12
+ import torch
13
+ import gradio as gr
14
+ import subprocess
15
+ import urllib.request
16
+ import ssl
17
+ import sys
18
 
19
+ # Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
20
+ ssl._create_default_https_context = ssl._create_unverified_context
21
 
22
+ sys.path.append(os.path.join(os.path.dirname(__file__), "citydreamer"))
23
+ # Import CityDreamer modules
24
+ # import citydreamer.model
25
+ # import citydreamer.inference
26
 
27
 
28
  def setup_runtime_env():
29
+ subprocess.call(["pip", "freeze"])
30
+ ext_dir = os.path.join(os.path.dirname(__file__), "citydreamer", "extensions")
31
+ for e in os.listdir(ext_dir):
32
+ if not os.path.isdir(e):
33
+ continue
34
+ subprocess.call(["pip", "install", "."], workdir=os.path.join(ext_dir, e))
35
 
36
 
37
  def get_models():
38
+ if not os.path.exists("CityDreamer-Fgnd.pth"):
39
+ urllib.request.urlretrieve(
40
+ "https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Fgnd.pth",
41
+ "CityDreamer-Fgnd.pth",
42
+ )
43
+ if not os.path.exists("CityDreamer-Bgnd.pth"):
44
+ urllib.request.urlretrieve(
45
+ "https://huggingface.co/hzxie/city-dreamer/resolve/main/CityDreamer-Bgnd.pth",
46
+ "CityDreamer-Bgnd.pth",
47
+ )
48
+
49
+ bgm_ckpt = torch.load("CityDreamer-Bgnd.pth")
50
+ fgm_ckpt = torch.load("CityDreamer-Fgnd.pth")
51
+ bgm = citydreamer.model.GanCraftGenerator(bgm_ckpt["cfg"])
52
+ fgm = citydreamer.model.GanCraftGenerator(fgm_ckpt["cfg"])
53
+ if torch.cuda.is_available():
54
+ fgm = torch.nn.DataParallel(fgm).cuda().eval()
55
+ bgm = torch.nn.DataParallel(bgm).cuda().eval()
56
+
57
+ return bgm, fgm
58
 
59
 
60
  def get_generated_city(radius, altitude, azimuth):
61
+ print(radius, altitude, azimuth)
62
 
63
 
64
+ def main(debug):
65
+ title = "CityDreamer Demo πŸ™οΈ"
 
 
 
 
 
66
  with open("README.md", "r") as f:
67
  markdown = f.read()
68
+ desc = markdown[markdown.rfind("---") + 3:]
69
+ with open("ARTICLE.md", "r") as f:
70
+ arti = f.read()
71
+ with open("assets/style.css") as f:
72
+ css = f.read()
73
 
74
+ app = gr.Interface(
75
+ get_generated_city,
76
+ [
77
+ gr.Slider(
78
+ 128, 512, value=320, step=5, label="Camera Radius (m)"
79
+ ),
80
+ gr.Slider(
81
+ 256, 512, value=384, step=5, label="Camera Altitude (m)"
82
+ ),
83
+ gr.Slider(0, 360, value=180, step=5, label="Camera Azimuth (Β°)"),
84
+ ],
85
+ [gr.Image(type="numpy", label="Generated City")],
86
+ title=title,
87
+ description=desc,
88
+ article=arti,
89
+ allow_flagging="never",
90
+ css=css,
91
+ )
92
+ app.queue(api_open=False)
93
+ app.launch(debug=debug)
94
 
95
 
96
  if __name__ == "__main__":
97
+ logging.basicConfig(
98
+ format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
99
+ )
100
+ logging.info("Compile CUDA extensions...")
101
+ # setup_runtime_env()
102
+ logging.info("Downloading pretrained models...")
103
+ # fgm, bgm = get_models()
104
+ logging.info("Starting the main application...")
105
+ main(os.getenv("DEBUG") == "1")
assets/style.css ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ display: block;
3
+ }
4
+
5
+ p img {
6
+ display: inline-block;
7
+ }
requirements.txt CHANGED
@@ -1,2 +1,7 @@
1
- requests
2
- streamlit
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.12.0
3
+ torchvision
4
+
5
+ numpy
6
+ opencv-python
7
+ gradio