Zaiiida commited on
Commit
708d74e
1 Parent(s): c0af06f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -80
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import time # Для эмуляции времени загрузки
3
  import tempfile
4
  import numpy as np
5
  import torch
@@ -7,10 +7,10 @@ from PIL import Image
7
  from tsr.system import TSR
8
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
9
 
10
- # Проверяем наличие GPU
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
- # Загружаем модель
14
  model = TSR.from_pretrained(
15
  "stabilityai/TripoSR",
16
  config_name="config.yaml",
@@ -19,14 +19,12 @@ model = TSR.from_pretrained(
19
  model.renderer.set_chunk_size(131072)
20
  model.to(device)
21
 
22
-
23
- # Функция для проверки изображения
24
  def check_input_image(input_image):
25
  if input_image is None:
26
  raise gr.Error("No image uploaded!")
27
 
28
-
29
- # Функция обработки изображения
30
  def preprocess(input_image, do_remove_background, foreground_ratio):
31
  def fill_background(image):
32
  image = np.array(image).astype(np.float32) / 255.0
@@ -45,74 +43,65 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
45
  image = fill_background(image)
46
  return image
47
 
48
-
49
- # Функция генерации 3D модели
50
  def generate(image):
51
- time.sleep(3) # Эмуляция времени обработки
52
  scene_codes = model(image, device=device)
53
  mesh = model.extract_mesh(scene_codes)[0]
54
  mesh = to_gradio_3d_orientation(mesh)
55
  mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
56
  mesh.export(mesh_path2.name)
57
  return mesh_path2.name
 
 
58
  def start_loading(loader_id):
59
  return f"<script>document.getElementById('{loader_id}').style.display = 'block';</script>"
60
 
61
  def stop_loading(loader_id):
62
  return f"<script>document.getElementById('{loader_id}').style.display = 'none';</script>"
63
 
64
- # Настройка темы и CSS
65
- class CustomTheme(gr.themes.Base):
66
- def __init__(self):
67
- super().__init__()
68
- self.primary_hue = "#191a1e"
69
- self.background_fill_primary = "#191a1e"
70
- self.background_fill_secondary = "#191a1e"
71
- self.background_fill_tertiary = "#191a1e"
72
- self.text_color_primary = "#FFFFFF"
73
- self.text_color_secondary = "#FFFFFF"
74
- self.text_color_tertiary = "#FFFFFF"
75
- self.input_background_fill = "#191a1e"
76
- self.input_text_color = "#FFFFFF"
77
-
78
-
79
  css = """
80
- /* Общий стиль лоадеров */
81
  .loader {
82
- display: none; /* Скрыт по умолчанию */
83
  position: absolute;
84
  top: 50%;
85
  left: 50%;
86
  transform: translate(-50%, -50%);
87
  width: 40px;
88
  height: 40px;
89
- border: 5px solid #f3f3f3; /* Белый круг */
90
- border-top: 5px solid #5271FF; /* Синий круг */
91
  border-radius: 50%;
92
  animation: spin 1s linear infinite;
93
  }
94
 
95
- /* Анимация вращения */
96
  @keyframes spin {
97
- 0% {
98
- transform: translate(-50%, -50%) rotate(0deg);
99
- }
100
- 100% {
101
- transform: translate(-50%, -50%) rotate(360deg);
102
- }
103
  }
104
 
105
- /* Добавить позицию для контейнеров */
106
  #image-container,
107
  #process-container,
108
  #generate-container {
109
- position: relative; /* Для размещения лоадера внутри */
 
 
 
 
 
 
 
 
110
  }
111
  """
112
 
113
- # Интерфейс
114
  with gr.Blocks(css=css) as demo:
115
  with gr.Column():
 
116
  with gr.Row(elem_id="image-container"):
117
  input_image = gr.Image(
118
  label="Upload Image",
@@ -129,7 +118,8 @@ with gr.Blocks(css=css) as demo:
129
  height=300,
130
  )
131
  loading_bar_image = gr.HTML("<div class='loader' id='image-loader'></div>")
132
-
 
133
  with gr.Row(elem_id="process-container"):
134
  foreground_ratio = gr.Slider(
135
  label="Foreground Ratio",
@@ -140,7 +130,8 @@ with gr.Blocks(css=css) as demo:
140
  )
141
  do_remove_background = gr.Checkbox(label="Remove Background", value=True)
142
  loading_bar_process = gr.HTML("<div class='loader' id='process-loader'></div>")
143
-
 
144
  with gr.Row(elem_id="generate-container"):
145
  submit = gr.Button("Generate", elem_classes="generate-button")
146
  output_model = gr.Model3D(
@@ -149,46 +140,47 @@ with gr.Blocks(css=css) as demo:
149
  elem_classes="gr-model3d-container",
150
  )
151
  loading_bar_generate = gr.HTML("<div class='loader' id='generate-loader'></div>")
152
- # Обновленная цепочка действий для кнопки submit
153
- submit.click(
154
- fn=lambda: start_loading('image-loader'), # Показать лоадер для загрузки изображения
155
- inputs=[],
156
- outputs=[loading_bar_image]
157
- ).then(
158
- fn=check_input_image, # Проверить изображение
159
- inputs=[input_image],
160
- outputs=[]
161
- ).then(
162
- fn=lambda: stop_loading('image-loader'), # Скрыть лоадер для загрузки
163
- inputs=[],
164
- outputs=[loading_bar_image]
165
- ).then(
166
- fn=lambda: start_loading('process-loader'), # Показать лоадер для обработки
167
- inputs=[],
168
- outputs=[loading_bar_process]
169
- ).then(
170
- fn=preprocess, # Обработка изображения
171
- inputs=[input_image, do_remove_background, foreground_ratio],
172
- outputs=[processed_image]
173
- ).then(
174
- fn=lambda: stop_loading('process-loader'), # Скрыть лоадер для обработки
175
- inputs=[],
176
- outputs=[loading_bar_process]
177
- ).then(
178
- fn=lambda: start_loading('generate-loader'), # Показать лоадер для генерации
179
- inputs=[],
180
- outputs=[loading_bar_generate]
181
- ).then(
182
- fn=generate, # Генерация модели
183
- inputs=[processed_image],
184
- outputs=[output_model]
185
- ).then(
186
- fn=lambda: stop_loading('generate-loader'), # Скрыть лоадер для генерации
187
- inputs=[],
188
- outputs=[loading_bar_generate]
189
- )
190
-
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  demo.launch(
193
  server_name="0.0.0.0",
194
  server_port=7860,
 
1
  import gradio as gr
2
+ import time # For simulating loading times
3
  import tempfile
4
  import numpy as np
5
  import torch
 
7
  from tsr.system import TSR
8
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
9
 
10
+ # Check for GPU availability
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Load the model
14
  model = TSR.from_pretrained(
15
  "stabilityai/TripoSR",
16
  config_name="config.yaml",
 
19
  model.renderer.set_chunk_size(131072)
20
  model.to(device)
21
 
22
+ # Function to check image input
 
23
  def check_input_image(input_image):
24
  if input_image is None:
25
  raise gr.Error("No image uploaded!")
26
 
27
+ # Image preprocessing
 
28
  def preprocess(input_image, do_remove_background, foreground_ratio):
29
  def fill_background(image):
30
  image = np.array(image).astype(np.float32) / 255.0
 
43
  image = fill_background(image)
44
  return image
45
 
46
+ # 3D model generation
 
47
  def generate(image):
48
+ time.sleep(3) # Simulate processing time
49
  scene_codes = model(image, device=device)
50
  mesh = model.extract_mesh(scene_codes)[0]
51
  mesh = to_gradio_3d_orientation(mesh)
52
  mesh_path2 = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
53
  mesh.export(mesh_path2.name)
54
  return mesh_path2.name
55
+
56
+ # Loading bar controls
57
  def start_loading(loader_id):
58
  return f"<script>document.getElementById('{loader_id}').style.display = 'block';</script>"
59
 
60
  def stop_loading(loader_id):
61
  return f"<script>document.getElementById('{loader_id}').style.display = 'none';</script>"
62
 
63
+ # Custom CSS and theme
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  css = """
65
+ /* Loaders inside their respective rows */
66
  .loader {
67
+ display: none;
68
  position: absolute;
69
  top: 50%;
70
  left: 50%;
71
  transform: translate(-50%, -50%);
72
  width: 40px;
73
  height: 40px;
74
+ border: 5px solid #f3f3f3;
75
+ border-top: 5px solid #5271FF;
76
  border-radius: 50%;
77
  animation: spin 1s linear infinite;
78
  }
79
 
 
80
  @keyframes spin {
81
+ 0% { transform: translate(-50%, -50%) rotate(0deg); }
82
+ 100% { transform: translate(-50%, -50%) rotate(360deg); }
 
 
 
 
83
  }
84
 
85
+ /* Loader container positions */
86
  #image-container,
87
  #process-container,
88
  #generate-container {
89
+ position: relative;
90
+ }
91
+
92
+ /* Button styling remains as original */
93
+ .generate-button {
94
+ background-color: #5271FF !important;
95
+ color: #FFFFFF !important;
96
+ border: none;
97
+ font-weight: bold;
98
  }
99
  """
100
 
101
+ # Gradio Interface
102
  with gr.Blocks(css=css) as demo:
103
  with gr.Column():
104
+ # Row 1: Image upload
105
  with gr.Row(elem_id="image-container"):
106
  input_image = gr.Image(
107
  label="Upload Image",
 
118
  height=300,
119
  )
120
  loading_bar_image = gr.HTML("<div class='loader' id='image-loader'></div>")
121
+
122
+ # Row 2: Processing options
123
  with gr.Row(elem_id="process-container"):
124
  foreground_ratio = gr.Slider(
125
  label="Foreground Ratio",
 
130
  )
131
  do_remove_background = gr.Checkbox(label="Remove Background", value=True)
132
  loading_bar_process = gr.HTML("<div class='loader' id='process-loader'></div>")
133
+
134
+ # Row 3: Generate button and 3D model
135
  with gr.Row(elem_id="generate-container"):
136
  submit = gr.Button("Generate", elem_classes="generate-button")
137
  output_model = gr.Model3D(
 
140
  elem_classes="gr-model3d-container",
141
  )
142
  loading_bar_generate = gr.HTML("<div class='loader' id='generate-loader'></div>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ # Submit button functionality
145
+ submit.click(
146
+ fn=lambda: start_loading("image-loader"),
147
+ inputs=[],
148
+ outputs=[loading_bar_image],
149
+ ).then(
150
+ fn=check_input_image,
151
+ inputs=[input_image],
152
+ outputs=[]
153
+ ).then(
154
+ fn=lambda: stop_loading("image-loader"),
155
+ inputs=[],
156
+ outputs=[loading_bar_image],
157
+ ).then(
158
+ fn=lambda: start_loading("process-loader"),
159
+ inputs=[],
160
+ outputs=[loading_bar_process],
161
+ ).then(
162
+ fn=preprocess,
163
+ inputs=[input_image, do_remove_background, foreground_ratio],
164
+ outputs=[processed_image],
165
+ ).then(
166
+ fn=lambda: stop_loading("process-loader"),
167
+ inputs=[],
168
+ outputs=[loading_bar_process],
169
+ ).then(
170
+ fn=lambda: start_loading("generate-loader"),
171
+ inputs=[],
172
+ outputs=[loading_bar_generate],
173
+ ).then(
174
+ fn=generate,
175
+ inputs=[processed_image],
176
+ outputs=[output_model],
177
+ ).then(
178
+ fn=lambda: stop_loading("generate-loader"),
179
+ inputs=[],
180
+ outputs=[loading_bar_generate],
181
+ )
182
+
183
+ # Launch app
184
  demo.launch(
185
  server_name="0.0.0.0",
186
  server_port=7860,