# Spaces: Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import ViTImageProcessor, ViTForImageClassification | |
from PIL import Image | |
# The processor and model are created once at module import, outside the
# request handler, so every call to the handler reuses the same instances.
_CHECKPOINT = 'google/vit-large-patch32-384'

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

feature_extractor = ViTImageProcessor.from_pretrained(_CHECKPOINT)

# Move the weights to the chosen device, then switch to inference mode.
model = ViTForImageClassification.from_pretrained(_CHECKPOINT).to(device)
model.eval()
def process_image(input_image, learning_rate, iterations):
    """Modify an image by gradient ascent on the ViT encoder output.

    The image is preprocessed with the module-level ``feature_extractor``,
    then repeatedly updated in the direction that increases the sum of the
    encoder's ``last_hidden_state``. After each step the pixel tensor is
    clamped to [-1, 1].

    Args:
        input_image: PIL image supplied by the UI, or ``None`` when no image
            was provided.
        learning_rate: Step size applied to the input gradient on each update.
        iterations: Number of update steps. Accepted as int or float (the UI
            number field may deliver a float); it is truncated to an int.
            Values below 1 perform no update.

    Returns:
        A ``numpy.uint8`` array in height x width x channel layout holding the
        updated image, or ``None`` when ``input_image`` is ``None``.
    """
    if input_image is None:
        # Nothing was uploaded; return an empty result instead of failing on
        # ``None.convert``.
        return None

    # ``range`` rejects floats, and a numeric UI field can deliver e.g. 1.0.
    steps = int(iterations)

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device).requires_grad_(True)

    for _ in range(steps):
        activations = model.vit(pixel_values).last_hidden_state
        objective = activations.sum()
        # Only the gradient with respect to the input is needed, so ask for
        # exactly that instead of populating .grad on every model parameter.
        (grad,) = torch.autograd.grad(objective, pixel_values)
        with torch.no_grad():
            pixel_values.add_(learning_rate * grad).clamp_(-1, 1)

    # Map [-1, 1] back to [0, 255]. This assumes the processor normalizes with
    # mean 0.5 / std 0.5 -- TODO confirm against the checkpoint's
    # preprocessor config.
    restored = 127.5 + pixel_values.detach().squeeze(0).permute(1, 2, 0).cpu() * 127.5
    # Clamp before the cast so out-of-range values cannot wrap around in uint8
    # (the loop's clamp never runs when steps == 0).
    return restored.clamp(0, 255).numpy().astype(np.uint8)
# Web UI: an uploaded image plus two numeric controls are passed to
# process_image, and the array it returns is displayed as an image.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=0.01, label="Learning Rate"),
        # precision=0 makes Gradio round the field and pass an int, which is
        # what an iteration count has to be.
        gr.Number(value=1, label="Iterations", precision=0),
    ],
    outputs=gr.Image(type="numpy", label="Processed Image"),
)
iface.launch()