Spaces:
Paused
Paused
Upload 4 files
Browse files- .gitattributes +1 -0
- README.md +6 -5
- app.py +182 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
1 |
---
|
2 |
+
title: TripoSR
|
3 |
+
emoji: 🐳
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.20.1
|
8 |
+
python_version: 3.10.13
|
9 |
app_file: app.py
|
10 |
pinned: false
|
11 |
license: mit
|
app.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""FRAME AI — text-to-3D Gradio Space built on TripoSR.

Module-level setup: installs the torchmcubes wheel bundled with the Space,
loads the TripoSR model (GPU when available), and creates a shared rembg
session for background removal.
"""

import logging
import os
import boto3
import json
import shlex
import subprocess
import tempfile
import time
import base64
import gradio as gr
import numpy as np
import rembg
import spaces
import torch
from PIL import Image
from functools import partial
import io

# SECURITY: a commented-out boto3 client containing hard-coded AWS access
# keys was removed from here. Those credentials were committed to a public
# repository and must be considered compromised — rotate them immediately.
# Credentials should come from the standard AWS chain (environment
# variables, shared config, or an attached IAM role), never from source.

# torchmcubes has no prebuilt wheel for this environment on PyPI, so install
# the wheel shipped inside the repo before importing tsr (which requires it).
subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation


HEADER = """FRAME AI"""

# Prefer the first CUDA device; fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# Larger render chunk size trades peak GPU memory for speed.
model.renderer.set_chunk_size(131072)
model.to(device)

# One shared rembg session avoids re-initializing its model on every request.
rembg_session = rembg.new_session()
48 |
+
|
49 |
+
def generate_image_from_text(pos_prompt):
    """Generate an image from a text prompt via AWS Bedrock (Stable Diffusion XL).

    Parameters
    ----------
    pos_prompt : str
        Positive text prompt describing the desired image.

    Returns
    -------
    PIL.Image.Image
        The generated image, decoded from Bedrock's base64 response.
    """
    # SECURITY FIX: hard-coded AWS access keys were removed from this call.
    # The previously embedded credentials are compromised and must be rotated.
    # boto3 now resolves credentials via its standard chain (environment
    # variables, shared config file, or an attached IAM role).
    bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')
    parameters = {
        'text_prompts': [
            {'text': pos_prompt, 'weight': 1},
            # Negative prompt steering the model away from low-quality output.
            {'text': """Blurry, unnatural, ugly, pixelated obscure, dull, artifacts, duplicate, bad quality, low resolution, cropped, out of frame, out of focus""", 'weight': -1},
        ],
        'cfg_scale': 7, 'seed': 0, 'samples': 1,
    }
    request_body = json.dumps(parameters)
    response = bedrock_runtime.invoke_model(body=request_body, modelId='stability.stable-diffusion-xl-v1')
    response_body = json.loads(response.get('body').read())
    # Bedrock returns the image as base64 in artifacts[0].
    base64_image_data = base64.b64decode(response_body['artifacts'][0]['base64'])

    return Image.open(io.BytesIO(base64_image_data))
61 |
+
|
62 |
+
def check_input_image(input_image):
    """Abort the Gradio event chain with an error when no input was provided."""
    if input_image is not None:
        return
    raise gr.Error("No image uploaded!")
65 |
+
|
66 |
+
def preprocess(input_image, do_remove_background, foreground_ratio):
    """Prepare an input image for TripoSR.

    With background removal enabled, the image is converted to RGB, its
    background stripped via rembg, the foreground rescaled to
    ``foreground_ratio`` of the frame, and the transparent region composited
    onto 50% grey. Otherwise the image passes through unchanged, except that
    RGBA inputs still get their alpha composited onto grey.
    """

    def _composite_on_grey(rgba_image):
        # Alpha-blend the RGB channels over a uniform 0.5 grey background.
        arr = np.array(rgba_image).astype(np.float32) / 255.0
        alpha = arr[:, :, 3:4]
        blended = arr[:, :, :3] * alpha + (1 - alpha) * 0.5
        return Image.fromarray((blended * 255.0).astype(np.uint8))

    if not do_remove_background:
        result = input_image
        if result.mode == "RGBA":
            result = _composite_on_grey(result)
        return result

    result = remove_background(input_image.convert("RGB"), rembg_session)
    result = resize_foreground(result, foreground_ratio)
    return _composite_on_grey(result)
83 |
+
|
84 |
+
@spaces.GPU
def generate(image, mc_resolution, formats=["obj", "glb"]):
    """Run TripoSR on a preprocessed image and export the mesh to temp files.

    Returns a pair of file paths: (obj_path, glb_path).
    ``formats`` is kept for interface compatibility but is never read —
    both formats are always exported.
    """
    scene_codes = model(image, device=device)
    mesh = to_gradio_3d_orientation(
        model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    )

    glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
    mesh.export(glb_file.name)

    obj_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
    mesh.apply_scale([-1, 1, 1])  # Otherwise the visualized .obj will be flipped
    mesh.export(obj_file.name)

    return obj_file.name, glb_file.name
98 |
+
|
99 |
+
def run_example(text_prompt, do_remove_background, foreground_ratio, mc_resolution):
    """Full text-to-3D pipeline: prompt -> image -> preprocess -> mesh.

    Returns (preprocessed_image, obj_path, glb_path), matching the
    (processed_image, output_model_obj, output_model_glb) UI outputs.
    """
    # Text prompt -> image via Bedrock SDXL.
    source_image = generate_image_from_text(text_prompt)

    # Background handling / grey compositing.
    prepared = preprocess(source_image, do_remove_background, foreground_ratio)

    # Image -> exported 3D mesh files.
    obj_path, glb_path = generate(prepared, mc_resolution, ["obj", "glb"])

    return prepared, obj_path, glb_path
110 |
+
|
111 |
+
# --- Gradio UI wiring ----------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                # The user types a prompt; the image itself is generated server-side.
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Enter a text prompt for image generation"
                )
                # Kept for layout/interface parity with the upstream TripoSR
                # demo but never shown — the image comes from the prompt.
                input_image = gr.Image(
                    label="Generated Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                    visible=False  # Hidden since we generate the image from text
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    # Fraction of the frame the foreground is resized to occupy.
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
                    # Higher resolution -> finer mesh, slower extraction.
                    mc_resolution = gr.Slider(
                        label="Marching Cubes Resolution",
                        minimum=32,
                        maximum=320,
                        value=256,
                        step=32
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(
                    label="Output Model (OBJ Format)",
                    interactive=False,
                )
                gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(
                    label="Output Model (GLB Format)",
                    interactive=False,
                )
                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
    with gr.Row(variant="panel"):
        # NOTE(review): examples are file paths from examples/ fed into the
        # *text prompt* input — confirm those files are meant to act as
        # prompt strings rather than images (this mirrors the upstream
        # image-based demo and may be unintended here).
        gr.Examples(
            examples=[
                os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
            ],
            inputs=[text_prompt],
            outputs=[processed_image, output_model_obj, output_model_glb],
            cache_examples=True,
            fn=partial(run_example, do_remove_background=True, foreground_ratio=0.85, mc_resolution=256),
            label="Examples",
            examples_per_page=20
        )
    # check_input_image gates the chain: run_example only fires on success.
    # NOTE(review): it receives the text prompt, so the "No image uploaded!"
    # message shown on an empty prompt is misleading — confirm intent.
    submit.click(fn=check_input_image, inputs=[text_prompt]).success(
        fn=run_example,
        inputs=[text_prompt, do_remove_background, foreground_ratio, mc_resolution],
        outputs=[processed_image, output_model_obj, output_model_glb],
    )

# Queue limits concurrent pending requests to 10 before rejecting new ones.
demo.queue(max_size=10)
demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
omegaconf==2.3.0
|
2 |
+
Pillow==10.1.0
|
3 |
+
einops==0.7.0
|
4 |
+
torch==2.0.1
|
5 |
+
transformers==4.35.0
|
6 |
+
trimesh==4.0.5
|
7 |
+
rembg
|
8 |
+
huggingface-hub
|
9 |
+
gradio
|
10 |
+
boto3
|