Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torch.nn import BCEWithLogitsLoss | |
import numpy as np | |
from transformers import ViTImageProcessor, ViTForImageClassification | |
from PIL import Image | |
# Load the model and image processor once at import time so every request reuses them.
# A single checkpoint name keeps the processor and the model in sync.
MODEL_NAME = 'google/vit-large-patch32-384'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()
# Only the input pixels are optimised, never the weights. Freezing the parameters
# means backward passes neither compute nor store gradients for them.
model.requires_grad_(False)
def get_encoder_activations(x):
    """Return the encoder's final hidden state at the first (CLS) token position.

    Args:
        x: Pixel tensor passed straight to ``model.vit``. Presumably this is
            (batch, 3, H, W) as produced by the image processor — TODO confirm.

    Returns:
        ``last_hidden_state[:, 0, :]``: one hidden vector per batch element,
        which is what ``model.classifier`` consumes.
    """
    hidden_states = model.vit(x).last_hidden_state
    cls_position = 0
    return hidden_states[:, cls_position, :]
def _sample_targets(n_classes, n_targets, seed):
    """Return a float multi-hot vector of length ``n_classes`` with ``n_targets`` ones.

    The positions come from ``torch.randperm`` driven by a private generator
    seeded with ``seed``. The choice is therefore reproducible without mutating
    torch's global RNG, which concurrent requests would otherwise share.
    ``n_targets`` is clipped to ``[0, n_classes]``.
    """
    generator = torch.Generator().manual_seed(int(seed))
    count = max(0, min(int(n_targets), n_classes))
    targets = torch.zeros(n_classes)
    targets[torch.randperm(n_classes, generator=generator)[:count]] = 1
    return targets


def _to_uint8_image(pixel_values):
    """Convert a (1, 3, H, W) tensor in [-1, 1] into an (H, W, 3) uint8 array."""
    # Inverse of the x -> x / 127.5 - 1 normalisation assumed for this checkpoint
    # (the loop's clamp to [-1, 1] relies on the same assumption — TODO confirm).
    image = (pixel_values[0].detach().cpu().permute(1, 2, 0) + 1) * 127.5
    # Round rather than truncate, and clamp so no value can wrap around in uint8.
    return image.round().clamp(0, 255).numpy().astype(np.uint8)


def process_image(input_image, learning_rate, iterations, n_targets, seed):
    """Optimise ``input_image`` so the classifier fires on randomly chosen classes.

    The function samples ``n_targets`` classes as a multi-hot target. It then
    runs ``iterations`` steps of gradient descent on the pixels, minimising a
    summed BCE-with-logits loss against that target. This pushes the sampled
    classes' logits up and every other logit down. Pixels are clamped to
    [-1, 1] after each step.

    Args:
        input_image: PIL image from the Gradio input, or None.
        learning_rate: Step size applied to the raw pixel gradient.
        iterations: Number of optimisation steps (coerced with ``int``).
        n_targets: How many classes to sample as targets (coerced with ``int``).
        seed: Seed for choosing the target classes (coerced with ``int``).

    Returns:
        An (H, W, 3) uint8 array of the optimised image at the processor's
        resolution, or None when no image was provided.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device).requires_grad_(True)

    targets = _sample_targets(model.config.num_labels, n_targets, seed).to(device)
    loss_fn = BCEWithLogitsLoss(reduction='sum')

    for _ in range(int(iterations)):
        logits = model.classifier(get_encoder_activations(pixel_values)[0])
        loss = loss_fn(logits, targets)
        # Take the gradient w.r.t. the pixels only. Calling backward() would also
        # accumulate .grad on the shared global model, and those gradients would
        # outlive the request.
        (grad,) = torch.autograd.grad(loss, pixel_values)
        with torch.no_grad():
            # Descend the loss so the target logits rise ("Maximise" in the UI).
            # The previous `+=` ascended it and suppressed the targets instead.
            pixel_values -= learning_rate * grad
            pixel_values.clamp_(-1, 1)

    return _to_uint8_image(pixel_values)
# Build the Gradio UI. The input components map positionally onto
# process_image(input_image, learning_rate, iterations, n_targets, seed).
_controls = [
    gr.Image(type="pil"),
    gr.Number(value=1.0, minimum=0, label="Learning Rate"),
    gr.Number(value=2, minimum=1, label="Iterations"),
    gr.Number(
        value=250,
        minimum=1,
        maximum=1000,
        label="Number of Random Target Class Activations to Maximise",
    ),
    gr.Number(value=420, minimum=0, label="Seed"),
]
_dreamed = gr.Image(type="numpy", label="Dreamed Image")
iface = gr.Interface(fn=process_image, inputs=_controls, outputs=[_dreamed])
iface.launch()