reedmayhew commited on
Commit
4a66938
1 Parent(s): f1ee166

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -10,41 +10,38 @@ import spaces
10
  def upscale_image(image, model, processor, device):
11
  # Convert the image to RGB format
12
  image = image.convert("RGB")
13
-
14
  # Process the image for the model
15
  inputs = processor(image, return_tensors="pt")
16
-
17
  # Move inputs to the same device as model
18
  inputs = {k: v.to(device) for k, v in inputs.items()}
19
-
20
  # Perform inference (upscale)
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
-
24
  # Move output back to CPU for further processing
25
  output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
26
  output = np.moveaxis(output, source=0, destination=-1)
27
  output_image = (output * 255.0).round().astype(np.uint8) # Convert from float32 to uint8
28
-
29
  # Remove 32 pixels from the bottom and right of the image
30
  output_image = output_image[:-32, :-32]
31
-
32
  return Image.fromarray(output_image)
33
 
34
  @spaces.GPU
35
- def main(image, save_as_jpg=True):
36
  # Check if GPU is available and set the device accordingly
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
 
39
- realworld_model = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
40
-
41
- # Load the Swin2SR model and processor for 4x upscaling
42
- processor = AutoImageProcessor.from_pretrained(realworld_model)
43
- model = Swin2SRForImageSuperResolution.from_pretrained(realworld_model)
44
-
 
 
 
45
  # Move the model to the device (GPU or CPU)
46
  model.to(device)
47
-
48
  # Upscale the image
49
  upscaled_image = upscale_image(image, model, processor, device)
50
 
@@ -58,19 +55,24 @@ def main(image, save_as_jpg=True):
58
  return "upscaled_image.png"
59
 
60
  # Gradio interface
61
- def gradio_interface(image, save_as_jpg):
62
- return main(image, save_as_jpg)
63
 
64
  # Create a Gradio interface
65
  interface = gr.Interface(
66
  fn=gradio_interface,
67
  inputs=[
68
  gr.Image(type="pil", label="Upload Image"),
 
 
 
 
 
69
  gr.Checkbox(value=True, label="Save as JPEG"),
70
  ],
71
  outputs=gr.File(label="Download Upscaled Image"),
72
  title="Image Upscaler",
73
- description="Upload an image, upscale it, and download the new image.",
74
  )
75
 
76
  # Launch the interface
 
10
  def upscale_image(image, model, processor, device):
11
  # Convert the image to RGB format
12
  image = image.convert("RGB")
 
13
  # Process the image for the model
14
  inputs = processor(image, return_tensors="pt")
 
15
  # Move inputs to the same device as model
16
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
17
  # Perform inference (upscale)
18
  with torch.no_grad():
19
  outputs = model(**inputs)
 
20
  # Move output back to CPU for further processing
21
  output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
22
  output = np.moveaxis(output, source=0, destination=-1)
23
  output_image = (output * 255.0).round().astype(np.uint8) # Convert from float32 to uint8
 
24
  # Remove 32 pixels from the bottom and right of the image
25
  output_image = output_image[:-32, :-32]
 
26
  return Image.fromarray(output_image)
27
 
28
  @spaces.GPU
29
+ def main(image, model_choice, save_as_jpg=True):
30
  # Check if GPU is available and set the device accordingly
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+ # Define model paths
34
+ model_paths = {
35
+ "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
36
+ "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
37
+ }
38
+
39
+ # Load the selected Swin2SR model and processor for 4x upscaling
40
+ processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
41
+ model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice])
42
  # Move the model to the device (GPU or CPU)
43
  model.to(device)
44
+
45
  # Upscale the image
46
  upscaled_image = upscale_image(image, model, processor, device)
47
 
 
55
  return "upscaled_image.png"
56
 
57
  # Gradio interface
58
+ def gradio_interface(image, model_choice, save_as_jpg):
59
+ return main(image, model_choice, save_as_jpg)
60
 
61
  # Create a Gradio interface
62
  interface = gr.Interface(
63
  fn=gradio_interface,
64
  inputs=[
65
  gr.Image(type="pil", label="Upload Image"),
66
+ gr.Dropdown(
67
+ choices=["PSNR Match (Recommended)", "Pixel Perfect"],
68
+ label="Select Model",
69
+ value="PSNR Match (Recommended)"
70
+ ),
71
  gr.Checkbox(value=True, label="Save as JPEG"),
72
  ],
73
  outputs=gr.File(label="Download Upscaled Image"),
74
  title="Image Upscaler",
75
+ description="Upload an image, select a model, upscale it, and download the new image.",
76
  )
77
 
78
  # Launch the interface