# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal

import torch
import torch.nn as nn


class WMDetectionLoss(nn.Module):
    """Compute the sample-level watermark detection loss."""

    def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
        super().__init__()
        self.criterion = nn.NLLLoss()
        self.p_weight = p_weight
        self.n_weight = n_weight

    def forward(self, positive, negative, mask, message=None):
        positive = positive[:, :2, :]  # b 2+nbits t -> b 2 t
        negative = negative[:, :2, :]  # b 2+nbits t -> b 2 t

        # positive has shape [bsz, classes=2, time_steps]; the correct classes for
        # positive samples form a [bsz, time_steps] tensor filled with 1
        classes_shape = positive[:, 0, :]  # same as positive/negative with dim=1 dropped
        pos_correct_classes = torch.ones_like(classes_shape, dtype=torch.long)
        neg_correct_classes = torch.zeros_like(classes_shape, dtype=torch.long)

        # take the log because the network outputs softmax probabilities,
        # while NLLLoss expects log-probabilities
        positive = torch.log(positive)
        negative = torch.log(negative)

        if not torch.all(mask == 1):
            # pos_correct_classes: [bsz, time_steps], mask: [bsz, 1, time_steps].
            # The mask is applied to the watermark: it flips the target class from
            # 1 (positive) to 0 (negative) wherever the watermark is masked out.
            pos_correct_classes = pos_correct_classes * mask[:, 0, :].to(torch.long)
            loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
            # no need for a negative-class loss here, since part of the watermark
            # is already masked to the negative class
            return loss_p
        else:
            loss_p = self.p_weight * self.criterion(positive, pos_correct_classes)
            loss_n = self.n_weight * self.criterion(negative, neg_correct_classes)
            return loss_p + loss_n


class WMMbLoss(nn.Module):
    def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None:
        """
        Compute the multi-bit message decoding loss
        (https://arxiv.org/pdf/2401.17264).

        Args:
            temperature: temperature for loss computation
            loss_type: "bce" or "mse" between outputs and the original message
        """
        super().__init__()
        # BCEWithLogitsLoss is equivalent to softmax + NLLLoss when there is
        # only one output unit per bit
        self.bce_with_logits = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()
        self.loss_type = loss_type
        self.temperature = temperature

    def forward(self, positive, negative, mask, message):
        """
        Compute the decoding loss.

        Args:
            positive: outputs on watermarked samples [bsz, 2+nbits, time_steps]
            negative: outputs on non-watermarked samples [bsz, 2+nbits, time_steps]
            mask: watermark mask [bsz, 1, time_steps]
            message: original message [bsz, nbits] or None
        """
        # # negative is unused at the moment
        # negative = negative[:, 2:, :]  # b 2+nbits t -> b nbits t
        # negative = torch.masked_select(negative, mask)

        if message.size(0) == 0:
            return torch.tensor(0.0)

        positive = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
        assert positive.shape[-2] == message.shape[1], (
            "in decoding loss: encoder and decoder don't share nbits, "
            "are you using multi-bit?"
        )

        # keep only the time steps where mask is 1
        new_shape = [*positive.shape[:-1], -1]  # b nbits -1
        positive = torch.masked_select(positive, mask == 1).reshape(new_shape)

        message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2])  # b k -> b k t
        if self.loss_type == "bce":
            # the temperature here plays the same role as in a softmax
            loss = self.bce_with_logits(positive / self.temperature, message.float())
        elif self.loss_type == "mse":
            loss = self.mse(positive / self.temperature, message.float())
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
        return loss
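

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): an assumed example of how the
# two losses above could be driven, using random tensors shaped as described in
# the docstrings. Detector outputs are assumed to be [bsz, 2 + nbits, time_steps],
# with the first two channels holding softmax detection probabilities and the
# remaining channels per-bit logits; all concrete values here are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bsz, nbits, time_steps = 4, 16, 1000

    def fake_detector_output() -> torch.Tensor:
        # first two channels: softmax detection probabilities; rest: bit logits
        detection = torch.softmax(torch.randn(bsz, 2, time_steps), dim=1)
        bits = torch.randn(bsz, nbits, time_steps)
        return torch.cat([detection, bits], dim=1)  # [bsz, 2 + nbits, time_steps]

    positive = fake_detector_output()  # detector output on watermarked audio
    negative = fake_detector_output()  # detector output on non-watermarked audio
    mask = torch.ones(bsz, 1, time_steps)        # 1 where the watermark is present
    message = torch.randint(0, 2, (bsz, nbits))  # random nbits-bit message

    detection_loss = WMDetectionLoss()(positive, negative, mask)
    decoding_loss = WMMbLoss(temperature=0.1, loss_type="bce")(
        positive, negative, mask, message
    )
    print(f"detection loss: {detection_loss.item():.4f}")
    print(f"decoding loss:  {decoding_loss.item():.4f}")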