ascarlettvfx commited on
Commit
2addbd2
1 Parent(s): c9aa520

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from gradio_client import Client, file
3
  from PIL import Image
4
  import numpy as np
5
  import io
@@ -8,20 +8,24 @@ import tempfile
8
  def process_image(image):
9
  client = Client("prs-eth/marigold")
10
 
11
- # Save the uploaded image temporarily
 
 
 
 
12
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpeg") as tmp:
13
  image.save(tmp, format='JPEG')
14
- tmp_path = tmp.name
15
 
16
  # Call the API with necessary parameters
17
  result = client.predict(
18
- file(tmp_path), # filepath for 'Input Image' Image component
19
  20, # Ensemble size
20
  10, # Number of denoising steps
21
  "0", # Processing resolution
22
- file(tmp_path), # Placeholder for 'Predicted depth (16-bit)'
23
- file(tmp_path), # Placeholder for 'Predicted depth (32-bit)'
24
- file(tmp_path), # Placeholder for 'Predicted depth (red-near, blue-far)'
25
  0, # Relative position of the near plane
26
  0, # Relative position of the far plane
27
  0, # Embossing level
@@ -32,7 +36,7 @@ def process_image(image):
32
 
33
  # Handle the returned file path for the depth image
34
  if result and 'depth_outputs' in result:
35
- depth_image_path = result['depth_outputs'][0] # Assuming the result is a local path to the image
36
  depth_image = Image.open(depth_image_path)
37
  depth_image.load() # Ensure the image is loaded completely
38
  return depth_image
 
1
  import gradio as gr
2
+ from gradio_client import Client, handle_file
3
  from PIL import Image
4
  import numpy as np
5
  import io
 
8
  def process_image(image):
9
  client = Client("prs-eth/marigold")
10
 
11
+ # Convert numpy array to PIL Image if needed
12
+ if isinstance(image, np.ndarray):
13
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
14
+
15
+ # Save the PIL Image to a temporary file
16
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpeg") as tmp:
17
  image.save(tmp, format='JPEG')
18
+ tmp_path = tmp.name # Get the file path
19
 
20
  # Call the API with necessary parameters
21
  result = client.predict(
22
+ handle_file(tmp_path), # filepath for 'Input Image' Image component
23
  20, # Ensemble size
24
  10, # Number of denoising steps
25
  "0", # Processing resolution
26
+ handle_file(tmp_path), # Placeholder for 'Predicted depth (16-bit)'
27
+ handle_file(tmp_path), # Placeholder for 'Predicted depth (32-bit)'
28
+ handle_file(tmp_path), # Placeholder for 'Predicted depth (red-near, blue-far)'
29
  0, # Relative position of the near plane
30
  0, # Relative position of the far plane
31
  0, # Embossing level
 
36
 
37
  # Handle the returned file path for the depth image
38
  if result and 'depth_outputs' in result:
39
+ depth_image_path = result['depth_outputs'][0]
40
  depth_image = Image.open(depth_image_path)
41
  depth_image.load() # Ensure the image is loaded completely
42
  return depth_image