MusicLM / audiocraft /models /watermark.py
SunilGopal's picture
Upload 214 files
6a662e6 verified
raw
history blame
3.56 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from audiocraft.models.loaders import load_audioseal_models
class WMModel(ABC, nn.Module):
"""
A wrapper interface to different watermarking models for
training or evaluation purporses
"""
@abstractmethod
def get_watermark(
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
) -> torch.Tensor:
"""Get the watermark from an audio tensor and a message.
If the input message is None, a random message of
n bits {0,1} will be generated
"""
@abstractmethod
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
"""Detect the watermarks from the audio signal
Args:
x: Audio signal, size batch x frames
Returns:
tensor of size (B, 2+n, frames) where:
Detection results of shape (B, 2, frames)
Message decoding results of shape (B, n, frames)
"""
class AudioSeal(WMModel):
"""Wrap Audioseal (https://github.com/facebookresearch/audioseal) for the
training and evaluation. The generator and detector are jointly trained
"""
def __init__(
self,
generator: nn.Module,
detector: nn.Module,
nbits: int = 0,
):
super().__init__()
self.generator = generator # type: ignore
self.detector = detector # type: ignore
# Allow to re-train an n-bit model with new 0-bit message
self.nbits = nbits if nbits else self.generator.msg_processor.nbits
def get_watermark(
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
) -> torch.Tensor:
return self.generator.get_watermark(x, message=message, sample_rate=sample_rate)
def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
"""
Detect the watermarks from the audio signal. The first two units of the output
are used for detection, the rest is used to decode the message. If the audio is
not watermarked, the message will be random.
Args:
x: Audio signal, size batch x frames
Returns
torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T).
"""
# Getting the direct decoded message from the detector
result = self.detector.detector(x) # b x 2+nbits
# hardcode softmax on 2 first units used for detection
result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)
return result
def forward( # generator
self,
x: torch.Tensor,
message: tp.Optional[torch.Tensor] = None,
sample_rate: int = 16_000,
alpha: float = 1.0,
) -> torch.Tensor:
"""Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)"""
wm = self.get_watermark(x, message)
return x + alpha * wm
@staticmethod
def get_pretrained(name="base", device=None) -> WMModel:
if device is None:
if torch.cuda.device_count():
device = "cuda"
else:
device = "cpu"
return load_audioseal_models("facebook/audioseal", filename=name, device=device)