reedmayhew commited on
Commit
13a4c81
1 Parent(s): 0782bc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -35
app.py CHANGED
@@ -22,28 +22,28 @@ def resize_image(image, max_size=2048):
22
 
23
  # Function to upscale an image using Swin2SR
24
  def upscale_image(image, model, processor, device):
25
- # Convert the image to RGB format
26
- image = image.convert("RGB")
27
- # Process the image for the model
28
- inputs = processor(image, return_tensors="pt")
29
- # Move inputs to the same device as model
30
- inputs = {k: v.to(device) for k, v in inputs.items()}
31
- # Perform inference (upscale)
32
- with torch.no_grad():
33
- outputs = model(**inputs)
34
- # Move output back to CPU for further processing
35
- output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
36
- output = np.moveaxis(output, source=0, destination=-1)
37
- output_image = (output * 255.0).round().astype(np.uint8) # Convert from float32 to uint8
38
- # Remove 32 pixels from the bottom and right of the image
39
- output_image = output_image[:-32, :-32]
40
- return Image.fromarray(output_image)
 
 
 
41
 
42
  @spaces.GPU
43
  def main(image, model_choice, save_as_jpg=True):
44
- # Check if GPU is available and set the device accordingly
45
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
-
47
  # Resize the input image
48
  image = resize_image(image)
49
 
@@ -56,24 +56,33 @@ def main(image, model_choice, save_as_jpg=True):
56
  # Load the selected Swin2SR model and processor for 4x upscaling
57
  processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
58
  model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice])
59
- # Move the model to the device (GPU or CPU)
60
- model.to(device)
61
-
62
- # Upscale the image
63
- upscaled_image = upscale_image(image, model, processor, device)
64
 
65
- if save_as_jpg:
66
- # Save the upscaled image as JPG with 98% compression
67
- upscaled_image.save("upscaled_image.jpg", quality=98)
68
- return "upscaled_image.jpg"
69
- else:
70
- # Save the upscaled image as PNG
71
- upscaled_image.save("upscaled_image.png")
72
- return "upscaled_image.png"
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Gradio interface
75
  def gradio_interface(image, model_choice, save_as_jpg):
76
- return main(image, model_choice, save_as_jpg)
 
 
 
77
 
78
  # Create a Gradio interface
79
  interface = gr.Interface(
@@ -87,9 +96,12 @@ interface = gr.Interface(
87
  ),
88
  gr.Checkbox(value=True, label="Save as JPEG"),
89
  ],
90
- outputs=gr.File(label="Download Upscaled Image"),
 
 
 
91
  title="Image Upscaler",
92
- description="Upload an image, select a model, upscale it, and download the new image. Images larger than 2048x2048 will be resized while maintaining aspect ratio.",
93
  )
94
 
95
  # Launch the interface
 
22
 
23
  # Function to upscale an image using Swin2SR
24
  def upscale_image(image, model, processor, device):
25
+ try:
26
+ # Convert the image to RGB format
27
+ image = image.convert("RGB")
28
+ # Process the image for the model
29
+ inputs = processor(image, return_tensors="pt")
30
+ # Move inputs to the same device as model
31
+ inputs = {k: v.to(device) for k, v in inputs.items()}
32
+ # Perform inference (upscale)
33
+ with torch.no_grad():
34
+ outputs = model(**inputs)
35
+ # Move output back to CPU for further processing
36
+ output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
37
+ output = np.moveaxis(output, source=0, destination=-1)
38
+ output_image = (output * 255.0).round().astype(np.uint8) # Convert from float32 to uint8
39
+ # Remove 32 pixels from the bottom and right of the image
40
+ output_image = output_image[:-32, :-32]
41
+ return Image.fromarray(output_image), None
42
+ except RuntimeError as e:
43
+ return None, str(e)
44
 
45
  @spaces.GPU
46
  def main(image, model_choice, save_as_jpg=True):
 
 
 
47
  # Resize the input image
48
  image = resize_image(image)
49
 
 
56
  # Load the selected Swin2SR model and processor for 4x upscaling
57
  processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
58
  model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice])
 
 
 
 
 
59
 
60
+ # Try GPU first, fallback to CPU if there's an error
61
+ for device in [torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.device("cpu")]:
62
+ model.to(device)
63
+ upscaled_image, error = upscale_image(image, model, processor, device)
64
+
65
+ if upscaled_image is not None:
66
+ if save_as_jpg:
67
+ # Save the upscaled image as JPG with 98% compression
68
+ upscaled_image.save("upscaled_image.jpg", quality=98)
69
+ return "upscaled_image.jpg"
70
+ else:
71
+ # Save the upscaled image as PNG
72
+ upscaled_image.save("upscaled_image.png")
73
+ return "upscaled_image.png"
74
+
75
+ if device.type == "cpu":
76
+ return f"Error: Unable to process the image. {error}"
77
+
78
+ return "Error: Unable to process the image on both GPU and CPU."
79
 
80
  # Gradio interface
81
  def gradio_interface(image, model_choice, save_as_jpg):
82
+ result = main(image, model_choice, save_as_jpg)
83
+ if result.startswith("Error:"):
84
+ return gr.update(value=None), result
85
+ return result, None
86
 
87
  # Create a Gradio interface
88
  interface = gr.Interface(
 
96
  ),
97
  gr.Checkbox(value=True, label="Save as JPEG"),
98
  ],
99
+ outputs=[
100
+ gr.File(label="Download Upscaled Image"),
101
+ gr.Textbox(label="Error Message", visible=True)
102
+ ],
103
  title="Image Upscaler",
104
+ description="Upload an image, select a model, upscale it, and download the new image. Images larger than 2048x2048 will be resized while maintaining aspect ratio. If GPU processing fails, it will attempt to process on CPU.",
105
  )
106
 
107
  # Launch the interface