IAMTFRMZA commited on
Commit
b197a35
1 Parent(s): 5bc266c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -73
app.py CHANGED
@@ -9,48 +9,45 @@ import rembg
9
  import torch
10
  from PIL import Image
11
  from functools import partial
12
- from serpapi import GoogleSearch
13
- import requests
14
- from io import BytesIO
15
- import matplotlib.pyplot as plt
16
 
17
  from tsr.system import TSR
18
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
19
 
20
- # Set your SerpApi key here
21
- SERPAPI_KEY = "3a786d94adb1d9739bb3a877b05dae35d231917f02ad89f74adfa014b567af3f"
22
 
23
  HEADER = """
24
  **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
25
-
26
  **Tips:**
27
  1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
28
  2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
29
  """
30
 
31
- def get_motorcycle_image(make, model):
32
- params = {
33
- "api_key": SERPAPI_KEY,
34
- "engine": "google",
35
- "q": f"{make} {model} motorcycle product photo",
36
- "tbm": "isch"
37
- }
38
-
39
- search = GoogleSearch(params)
40
- results = search.get_dict()
41
- if "images_results" in results:
42
- first_image = results["images_results"][0]
43
- image_url = first_image.get("original")
44
- if image_url:
45
- image_response = requests.get(image_url)
46
- image = Image.open(BytesIO(image_response.content))
47
- return image
48
- else:
49
- print("Image URL not found in results.")
50
- return None
51
- else:
52
- print("No image results found.")
53
- return None
 
 
 
54
 
55
  def preprocess(input_image, do_remove_background, foreground_ratio):
56
  def fill_background(image):
@@ -70,6 +67,7 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
70
  image = fill_background(image)
71
  return image
72
 
 
73
  def generate(image):
74
  scene_codes = model(image, device=device)
75
  mesh = model.extract_mesh(scene_codes)[0]
@@ -80,52 +78,23 @@ def generate(image):
80
  mesh.export(mesh_path2.name)
81
  return mesh_path.name, mesh_path2.name
82
 
83
- def run_example(make, model):
84
- image = get_motorcycle_image(make, model)
85
- if image:
86
- # Save the image
87
- input_image_path = '/content/motorcycle.jpg'
88
- image.save(input_image_path)
89
-
90
- # Load the image
91
- img = Image.open(input_image_path)
92
- output_image_path = '/content/motorcyclebg.png'
93
- img_no_bg = rembg_remove(img)
94
- img_no_bg.save(output_image_path)
95
-
96
- # Preprocess and generate 3D model
97
- preprocessed = preprocess(img_no_bg, False, 0.9)
98
- mesh_name, mesh_name2 = generate(preprocessed)
99
- return preprocessed, mesh_name, mesh_name2
100
- else:
101
- raise gr.Error("Image could not be fetched.")
102
-
103
- if torch.cuda.is_available():
104
- device = "cuda:0"
105
- else:
106
- device = "cpu"
107
-
108
- d = os.environ.get("DEVICE", None)
109
- if d != None:
110
- device = d
111
-
112
- model = TSR.from_pretrained(
113
- "stabilityai/TripoSR",
114
- config_name="config.yaml",
115
- weight_name="model.ckpt",
116
- )
117
- model.renderer.set_chunk_size(131072)
118
- model.to(device)
119
-
120
- rembg_session = rembg.new_session()
121
 
122
  with gr.Blocks() as demo:
123
  gr.Markdown(HEADER)
124
  with gr.Row(variant="panel"):
125
  with gr.Column():
126
  with gr.Row():
127
- make_input = gr.Textbox(label="Motorcycle Make", placeholder="Enter motorcycle make")
128
- model_input = gr.Textbox(label="Motorcycle Model", placeholder="Enter motorcycle model")
 
 
 
 
 
129
  processed_image = gr.Image(label="Processed Image", interactive=False)
130
  with gr.Row():
131
  with gr.Group():
@@ -152,8 +121,27 @@ with gr.Blocks() as demo:
152
  label="Output Model",
153
  interactive=False,
154
  )
155
- submit.click(fn=run_example, inputs=[make_input, model_input], outputs=[processed_image, output_model, output_model2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  demo.queue(max_size=10)
158
- demo.launch()
159
-
 
9
  import torch
10
  from PIL import Image
11
  from functools import partial
 
 
 
 
12
 
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
+ #HF_TOKEN = os.getenv("HF_TOKEN")
 
17
 
18
  HEADER = """
19
  **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
 
20
  **Tips:**
21
  1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
22
  2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
23
  """
24
 
25
+
26
+ if torch.cuda.is_available():
27
+ device = "cuda:0"
28
+ else:
29
+ device = "cpu"
30
+
31
+ d = os.environ.get("DEVICE", None)
32
+ if d != None:
33
+ device = d
34
+
35
+ model = TSR.from_pretrained(
36
+ "stabilityai/TripoSR",
37
+ config_name="config.yaml",
38
+ weight_name="model.ckpt",
39
+ # token=HF_TOKEN
40
+ )
41
+ model.renderer.set_chunk_size(131072)
42
+ model.to(device)
43
+
44
+ rembg_session = rembg.new_session()
45
+
46
+
47
+ def check_input_image(input_image):
48
+ if input_image is None:
49
+ raise gr.Error("No image uploaded!")
50
+
51
 
52
  def preprocess(input_image, do_remove_background, foreground_ratio):
53
  def fill_background(image):
 
67
  image = fill_background(image)
68
  return image
69
 
70
+
71
  def generate(image):
72
  scene_codes = model(image, device=device)
73
  mesh = model.extract_mesh(scene_codes)[0]
 
78
  mesh.export(mesh_path2.name)
79
  return mesh_path.name, mesh_path2.name
80
 
81
+ def run_example(image_pil):
82
+ preprocessed = preprocess(image_pil, False, 0.9)
83
+ mesh_name, mesn_name2 = generate(preprocessed)
84
+ return preprocessed, mesh_name, mesh_name2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  with gr.Blocks() as demo:
87
  gr.Markdown(HEADER)
88
  with gr.Row(variant="panel"):
89
  with gr.Column():
90
  with gr.Row():
91
+ input_image = gr.Image(
92
+ label="Input Image",
93
+ image_mode="RGBA",
94
+ sources="upload",
95
+ type="pil",
96
+ elem_id="content_image",
97
+ )
98
  processed_image = gr.Image(label="Processed Image", interactive=False)
99
  with gr.Row():
100
  with gr.Group():
 
121
  label="Output Model",
122
  interactive=False,
123
  )
124
+ with gr.Row(variant="panel"):
125
+ gr.Examples(
126
+ examples=[
127
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
128
+ ],
129
+ inputs=[input_image],
130
+ outputs=[processed_image, output_model, output_model2],
131
+ #cache_examples=True,
132
+ fn=partial(run_example),
133
+ label="Examples",
134
+ examples_per_page=20
135
+ )
136
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
137
+ fn=preprocess,
138
+ inputs=[input_image, do_remove_background, foreground_ratio],
139
+ outputs=[processed_image],
140
+ ).success(
141
+ fn=generate,
142
+ inputs=[processed_image],
143
+ outputs=[output_model, output_model2],
144
+ )
145
 
146
  demo.queue(max_size=10)
147
+ demo.launch()