# Source: Hugging Face Hub file (uploaded by imatag-vch), commit d5695ca:
# "Upload ResNetForZeroBitWatermarkDetection" (2.17 kB).
import torch
from torch import nn
from transformers import ResNetPreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
from transformers.image_processing_utils import BaseImageProcessor
from transformers import ResNetConfig, ResNetModel
from typing import Optional
class ResNetForZeroBitWatermarkDetection(ResNetPreTrainedModel):
    """ResNet backbone with a single-output head for zero-bit watermark detection.

    The backbone's pooled features go through a small MLP that produces one
    score per image. The score is recalibrated through the CDF of a symmetric
    generalized Gaussian with shape parameter ``beta`` and returned as a
    binary logit (log-odds).
    """

    def __init__(self, config):
        super().__init__(config)
        self.resnet = ResNetModel(config)
        # Flatten collapses the backbone's pooled output before the MLP
        # (presumably pooler_output is (batch, hidden_sizes[-1], 1, 1) -- confirm).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )
        # Shape parameter of the generalized-Gaussian calibration. It is a buffer
        # rather than a Parameter, so it is stored in the state dict but is not
        # returned by .parameters(). The default 1.0 corresponds to a Laplacian.
        self.register_buffer("beta", torch.tensor([1.0]))
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        """Score images for the presence of a zero-bit watermark.

        Args:
            pixel_values: Preprocessed images passed to the ResNet backbone
                (presumably ``(batch, channels, height, width)`` -- confirm
                against the image processor).
            output_hidden_states: Forwarded to the backbone; when true, its
                hidden states are included in the output.
            return_dict: Return an ``ImageClassifierOutputWithNoAttention``
                instead of a plain tuple. Defaults to ``config.use_return_dict``.

        Returns:
            Logits of shape ``(batch, 1)``: the log-odds ``log(p / (1 - p))`` of
            the calibrated probability ``p``. ``loss`` is always ``None``; this
            head does not compute a training loss. With ``return_dict=False``,
            the tuple ``(logits, *backbone_outputs[2:])`` is returned instead.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        score = self.classifier(pooled_output)

        # Generalized-Gaussian recalibration. Centering and scaling are already
        # folded into the last linear layer. The calibrated probability is
        #     p = 1/2 + sign(s)/2 * P(1/beta, |s|**beta),
        # where P is the regularized lower incomplete gamma function. For
        # beta == 1 this is the Laplace CDF 1/2 + sign(s)/2 * (1 - exp(-|s|)).
        # The log-odds simplify to
        #     logit(p) = sign(s) * (log1p(P) - log(Q)),   Q = 1 - P (upper gamma),
        # and are evaluated in that form. Forming p first and computing
        # log(p) - log1p(-p) rounds p to exactly 0 or 1 once |s| is moderately
        # large (roughly 17 for a Laplacian in float32), which gives infinite
        # logits. Q still underflows eventually, but only at much larger |s|.
        # At s == 0: P = 0 and Q = 1, so the logit is 0, as before.
        gamma_shape = 1 / self.beta
        t = torch.abs(score) ** self.beta
        lower = torch.special.gammainc(gamma_shape, t)
        upper = torch.special.gammaincc(gamma_shape, t)
        logits = torch.sign(score) * (torch.log1p(lower) - torch.log(upper))

        loss = None  # this head computes no training loss
        if not return_dict:
            return (logits,) + outputs[2:]
        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)