SoggyKiwi committed on
Commit
37ebd45
1 Parent(s): 03f7bd7
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -4,35 +4,35 @@ import numpy as np
4
  from transformers import ViTImageProcessor, ViTForImageClassification
5
  from PIL import Image
6
 
7
- def process_image(input_image, learning_rate, iterations):
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
11
- model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
12
- model.to(device)
13
- model.eval()
14
 
 
15
  def get_encoder_activations(x):
16
  encoder_output = model.vit(x)
17
  final_activations = encoder_output.last_hidden_state
18
  return final_activations
19
-
20
  image = input_image.convert('RGB')
21
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
22
- pixel_values.to(device)
23
  pixel_values.requires_grad_(True)
24
 
25
- for iteration in range(iterations.value):
26
  model.zero_grad()
27
  if pixel_values.grad is not None:
28
  pixel_values.grad.data.zero_()
29
 
30
- final_activations = get_encoder_activations(pixel_values.to('cuda'))
31
  target_sum = final_activations.sum()
32
  target_sum.backward()
33
 
34
  with torch.no_grad():
35
- pixel_values.data += learning_rate.value * pixel_values.grad.data
36
  pixel_values.data = torch.clamp(pixel_values.data, -1, 1)
37
 
38
  updated_pixel_values_np = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
@@ -50,4 +50,4 @@ iface = gr.Interface(
50
  outputs=gr.Image(type="numpy", label="Processed Image")
51
  )
52
 
53
- iface.launch()
 
4
  from transformers import ViTImageProcessor, ViTForImageClassification
5
  from PIL import Image
6
 
7
+ # Load model and feature extractor outside the function
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
10
+ model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
11
+ model.to(device)
12
+ model.eval()
 
13
 
14
+ def process_image(input_image, learning_rate, iterations):
15
  def get_encoder_activations(x):
16
  encoder_output = model.vit(x)
17
  final_activations = encoder_output.last_hidden_state
18
  return final_activations
19
+
20
  image = input_image.convert('RGB')
21
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
22
+ pixel_values = pixel_values.to(device)
23
  pixel_values.requires_grad_(True)
24
 
25
+ for iteration in range(iterations):
26
  model.zero_grad()
27
  if pixel_values.grad is not None:
28
  pixel_values.grad.data.zero_()
29
 
30
+ final_activations = get_encoder_activations(pixel_values)
31
  target_sum = final_activations.sum()
32
  target_sum.backward()
33
 
34
  with torch.no_grad():
35
+ pixel_values.data += learning_rate * pixel_values.grad.data
36
  pixel_values.data = torch.clamp(pixel_values.data, -1, 1)
37
 
38
  updated_pixel_values_np = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
 
50
  outputs=gr.Image(type="numpy", label="Processed Image")
51
  )
52
 
53
+ iface.launch()