Spaces:
Sleeping
Sleeping
File size: 2,502 Bytes
b7ebb88 03f7bd7 f93986b 03f7bd7 b7ebb88 37ebd45 03f7bd7 f93fa3d f93986b 07b1c90 353541c 37ebd45 03f7bd7 37ebd45 03f7bd7 a4244e1 f93986b 3c13f2b f93986b a4244e1 5c39195 03f7bd7 fa12e38 f93986b fa12e38 f93986b f93fa3d f93986b 03f7bd7 37ebd45 03f7bd7 b7ebb88 68fa56c 03f7bd7 f93986b b222813 68fa56c 93f9b8b b7ebb88 37ebd45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
import torch
from torch.nn import BCEWithLogitsLoss
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
# Load the image processor and classifier once at import time so every request reuses them.
MODEL_NAME = 'google/vit-large-patch32-384'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()
# Only the input pixels are optimised, never the weights. Freezing the parameters stops
# each backward pass from computing and storing a gradient for every model parameter.
model.requires_grad_(False)
def get_encoder_activations(x):
    """Return the first-token embedding from the ViT backbone's final hidden layer.

    Args:
        x: Preprocessed pixel tensor, passed unchanged to ``model.vit``.

    Returns:
        ``last_hidden_state[:, 0, :]``: one vector per batch element, taken at token
        position 0 (presumably the [CLS] token, which is what ``model.classifier``
        is applied to elsewhere in this file).
    """
    first_token = 0
    hidden = model.vit(x).last_hidden_state
    return hidden[:, first_token, :]
def process_image(input_image, learning_rate, iterations, n_targets, seed):
    """Optimise *input_image* so a random subset of classifier logits is maximised.

    ``n_targets`` randomly chosen classes are labelled 1 and all others 0. The
    preprocessed pixels are then updated by gradient descent on the summed binary
    cross-entropy between the classifier logits and those labels. This pushes the
    chosen class logits up and the remaining ones down.

    Args:
        input_image: PIL image from the Gradio input, or ``None`` when empty.
        learning_rate: Step size applied to the raw pixel gradient each iteration.
        iterations: Number of update steps (coerced with ``int``).
        n_targets: How many classes to mark as targets (coerced with ``int`` and
            clamped to ``[0, model.config.num_labels]``).
        seed: Seed for choosing the target classes (coerced with ``int``).

    Returns:
        An ``(H, W, 3)`` ``uint8`` NumPy array of the optimised image at the
        processor's output resolution, or ``None`` when no image was supplied.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device).requires_grad_(True)

    # Choose target classes with a private generator: reseeding torch's global RNG
    # on every request would leak state into, and race with, other requests.
    num_classes = model.config.num_labels
    n_chosen = max(0, min(int(n_targets), num_classes))  # a negative slice bound would pick ~all classes
    generator = torch.Generator().manual_seed(int(seed))
    targets = torch.zeros(num_classes)
    targets[torch.randperm(num_classes, generator=generator)[:n_chosen]] = 1
    targets = targets.to(device)

    criterion = BCEWithLogitsLoss(reduction='sum')
    for _ in range(int(iterations)):
        final_activations = get_encoder_activations(pixel_values)
        logits = model.classifier(final_activations[0])
        loss = criterion(logits, targets)
        # Gradient w.r.t. the pixels only; nothing accumulates in the parameters' .grad.
        (pixel_grad,) = torch.autograd.grad(loss, pixel_values)
        with torch.no_grad():
            # Descend the loss: that is what drives the label-1 (target) logits upward.
            pixel_values -= learning_rate * pixel_grad
            # Keep pixels in [-1, 1], the range the decoding below maps to [0, 255].
            pixel_values.clamp_(-1, 1)

    with torch.no_grad():
        # Invert the [-1, 1] normalisation (assumes processor mean = std = 0.5, as the
        # 127.5 scaling implies). Round instead of truncating, and clamp so stray
        # float error cannot wrap around in uint8.
        image_hwc = pixel_values[0].permute(1, 2, 0)
        image_uint8 = ((image_hwc + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)
    return image_uint8.cpu().numpy()
# Gradio UI: an image plus four numeric controls in, one "dreamed" image out.
input_components = [
    gr.Image(type="pil"),
    gr.Number(value=1.0, minimum=0, label="Learning Rate"),
    gr.Number(value=2, minimum=1, label="Iterations"),
    gr.Number(
        value=250,
        minimum=1,
        maximum=1000,
        label="Number of Random Target Class Activations to Maximise",
    ),
    gr.Number(value=420, minimum=0, label="Seed"),
]
output_components = [gr.Image(type="numpy", label="Dreamed Image")]

iface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
)
iface.launch()
|