# Adapted from Cheng An Hsieh, et al.: https://github.com/RewardMultiverse/reward-multiverse
import io

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from transformers import CLIPModel
def jpeg_compressibility(device):
    """Return a reward function that scores images by their JPEG file size."""

    def _fn(images):
        """
        args:
            images: float tensor of shape NCHW with values in [0, 1]
        returns:
            (sizes, images): JPEG sizes in kB (numpy array, one per image) and
            the images round-tripped back to an NCHW float tensor on `device`
        """
        org_type = images.dtype
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        transform_images_tensor = torch.tensor(images).to(device, dtype=org_type)
        transform_images_tensor = (transform_images_tensor.permute(0, 3, 1, 2) / 255).clamp(0, 1)  # NHWC -> NCHW
        transform_images_pil = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in transform_images_pil]
        for image, buffer in zip(transform_images_pil, buffers):
            image.save(buffer, format="JPEG", quality=95)
        sizes = [buffer.tell() / 1000 for buffer in buffers]  # bytes -> kB
        return np.array(sizes), transform_images_tensor

    return _fn
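
# Usage sketch (assumptions: a CUDA device is available and a random batch
# stands in for real generated images; the "reward" is the JPEG file size in
# kB, so larger values mean less compressible images):
#
#   compress_fn = jpeg_compressibility(torch.device("cuda"))
#   images = torch.rand(4, 3, 512, 512, device="cuda")  # NCHW floats in [0, 1]
#   sizes_kb, images_roundtrip = compress_fn(images)
#   print(sizes_kb)  # one JPEG size per image, e.g. array([31.4, 29.8, ...])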
class MLP(nn.Module):
    """Regression head mapping a 768-d CLIP image embedding to a scalar score."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(768, 512),  # 768 = CLIP ViT-L/14 image-embedding width
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
        )

    def forward(self, embed):
        return self.layers(embed)
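
# Quick shape check (assumption: a random batch stands in for real CLIP
# embeddings, which would normally be L2-normalized 768-d vectors):
#
#   mlp = MLP()
#   scores = mlp(torch.randn(8, 768))  # -> shape (8, 1), one scalar per embedding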
def jpegcompression_loss_fn(target=None,
                            grad_scale=0,
                            device=None,
                            accelerator=None,
                            torch_dtype=None,
                            reward_model_resume_from=None):
    """Build a differentiable loss that pushes images toward a target JPEG size."""
    scorer = JpegCompressionScorer(dtype=torch_dtype, model_path=reward_model_resume_from).to(device, dtype=torch_dtype)
    scorer.requires_grad_(False)  # the reward model stays frozen
    scorer.eval()

    def loss_fn(im_pix_un):
        if accelerator is not None and accelerator.mixed_precision == "fp16":
            with accelerator.autocast():
                rewards = scorer(im_pix_un)
        else:
            rewards = scorer(im_pix_un)
        if target is None:
            loss = rewards  # no target: minimize the predicted JPEG size directly
        else:
            loss = torch.abs(rewards - target)  # otherwise pull toward the target size
        return loss * grad_scale, rewards

    return loss_fn
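
# Usage sketch (assumptions: an `accelerate.Accelerator` is already set up and
# "mlp_weights.pt" is a hypothetical checkpoint for the MLP head above;
# `grad_scale` weights the reward term inside a larger training objective):
#
#   from accelerate import Accelerator
#   accelerator = Accelerator()
#   loss_fn = jpegcompression_loss_fn(target=50.0,
#                                     grad_scale=1.0,
#                                     device=accelerator.device,
#                                     accelerator=accelerator,
#                                     torch_dtype=torch.float32,
#                                     reward_model_resume_from="mlp_weights.pt")
#   loss, rewards = loss_fn(images)  # images: NCHW floats in [0, 1]
#   accelerator.backward(loss.mean())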
class JpegCompressionScorer(nn.Module):
    """Differentiable JPEG-size predictor: frozen CLIP features + trained MLP head."""

    def __init__(self, dtype=None, model_path=None):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip.requires_grad_(False)
        self.score_generator = MLP()
        if model_path:
            state_dict = torch.load(model_path, map_location="cpu")
            self.score_generator.load_state_dict(state_dict)
        if dtype:
            self.dtype = dtype
        self.target_size = (224, 224)  # CLIP input resolution
        # CLIP's standard normalization statistics
        self.normalize = torchvision.transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                                          std=[0.26862954, 0.26130258, 0.27577711])

    def set_device(self, device, inference_type):
        self.score_generator.to(device)

    def forward(self, images):
        # Resize and normalize NCHW images in [0, 1] to CLIP's expected input.
        im_pix = torchvision.transforms.Resize(self.target_size)(images)
        im_pix = self.normalize(im_pix).to(images.dtype)
        embed = self.clip.get_image_features(pixel_values=im_pix)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)  # L2-normalize
        return self.score_generator(embed).squeeze(1)
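
# End-to-end sketch tying the pieces together (assumption: the MLP head has
# been trained to regress the true JPEG size, making the scorer a
# differentiable proxy for the non-differentiable reward above):
#
#   device = torch.device("cuda")
#   true_sizes, _ = jpeg_compressibility(device)(images)  # exact, not differentiable
#   scorer = JpegCompressionScorer().to(device).eval()
#   pred_sizes = scorer(images)  # differentiable estimate, usable in a loss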