"""Gradio demo: gradient ascent on an input image against ViT encoder activations."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the model and image processor once at import time so that every
# request served by the Gradio app reuses the same weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
model.to(device)
model.eval()


def process_image(input_image, learning_rate, iterations):
    """Nudge an image so the summed ViT encoder activations increase.

    Runs ``iterations`` steps of gradient ascent on the preprocessed pixel
    values, using the sum of the encoder's ``last_hidden_state`` as the
    objective, and clamps the pixels to ``[-1, 1]`` after every step.

    Args:
        input_image: PIL image supplied by the Gradio ``Image`` component.
        learning_rate: Step size applied to the input gradient.
        iterations: Number of ascent steps. Gradio ``Number`` inputs may
            arrive as floats, so the value is truncated to an ``int``.

    Returns:
        A ``uint8`` NumPy array in height x width x channel layout holding
        the updated image at the resolution produced by the image processor.

    Raises:
        gr.Error: If no image was provided, a numeric field is empty, or
            ``iterations`` is negative.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if learning_rate is None or iterations is None:
        raise gr.Error("Learning rate and iterations must both be set.")

    step_size = float(learning_rate)
    # range() rejects floats, which is what gr.Number delivers by default.
    num_steps = int(iterations)
    if num_steps < 0:
        raise gr.Error("Iterations must be zero or a positive integer.")

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    for _ in range(num_steps):
        final_activations = model.vit(pixel_values).last_hidden_state
        # Differentiate with respect to the input only: this avoids
        # accumulating (and having to zero) gradients on the model weights.
        (input_grad,) = torch.autograd.grad(final_activations.sum(), pixel_values)

        with torch.no_grad():
            pixel_values.add_(input_grad, alpha=step_size)
            pixel_values.clamp_(-1, 1)

    # Map [-1, 1] back to [0, 255]. The clamp range and the 127.5 scaling
    # correspond to a mean=0.5 / std=0.5 normalisation of the input.
    restored = pixel_values.detach().squeeze(0).permute(1, 2, 0).cpu()
    restored = 127.5 + restored * 127.5
    # Round before the integer cast so float error cannot shift a value
    # down by one, and clip so the cast can never wrap around.
    restored = restored.round().clamp(0, 255)
    return restored.numpy().astype(np.uint8)


iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=0.01, label="Learning Rate"),
        # precision=0 makes Gradio hand the value over as an int.
        gr.Number(value=1, label="Iterations", precision=0),
    ],
    outputs=gr.Image(type="numpy", label="Processed Image"),
)

iface.launch()