File size: 5,062 Bytes
7ee3434 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/tokenizer.py
import re
from typing import Any, Dict, List, Optional, Pattern, Union
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
class AudioTokenizer:
    """Wrapper around a pretrained 24 kHz EnCodec model.

    On construction the model is created, its target bandwidth is set to
    6.0, weight normalization is removed from it, and it is moved to the
    selected device.

    Attributes:
        device: Read-only device the codec was moved to.
        codec: The EnCodec model, already placed on ``device``.
        sample_rate: ``sample_rate`` as reported by the EnCodec model.
        channels: ``channels`` as reported by the EnCodec model.
    """

    def __init__(self, device: Any = None) -> None:
        """Build the codec and place it on a device.

        Args:
            device: Device to move the model to. Any falsy value (including
                the default ``None``) selects ``cuda:0`` when CUDA is
                available and the CPU otherwise.
        """
        codec = EncodecModel.encodec_model_24khz()
        codec.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(codec)

        # Auto-select only when the caller did not name a device.
        if not device:
            fallback = "cuda:0" if torch.cuda.is_available() else "cpu"
            device = torch.device(fallback)

        self._device = device
        self.codec = codec.to(device)
        self.sample_rate = codec.sample_rate
        self.channels = codec.channels

    @property
    def device(self):
        """Device chosen at construction time."""
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode a waveform with the wrapped codec.

        Args:
            wav: Waveform tensor; it is moved to ``self.device`` before
                being handed to the codec.

        Returns:
            The value returned by ``self.codec.encode``.
            NOTE(review): other code in this file iterates over the result
            of ``encode`` and indexes each item, so the ``torch.Tensor``
            annotation may be too narrow -- confirm against EnCodec.
        """
        on_device = wav.to(self.device)
        return self.codec.encode(on_device)

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        """Decode encoded frames with the wrapped codec.

        Args:
            frames: Encoded frames, passed to the codec unchanged (they are
                not moved to ``self.device`` here).

        Returns:
            The value returned by ``self.codec.decode``.
        """
        return self.codec.decode(frames)
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
    """Load an audio file and encode it with ``tokenizer``.

    The file is read with ``torchaudio.load``, passed through
    ``convert_audio`` using the tokenizer's ``sample_rate`` and
    ``channels``, given a leading dimension of size one, and then encoded
    with gradient tracking disabled.

    Args:
        tokenizer: An instance of AudioTokenizer.
        audio_path: Path to the audio file.

    Returns:
        Whatever ``tokenizer.encode`` returns for the prepared waveform.

    Note:
        Nothing is caught or re-wrapped here: any error raised while
        loading, converting or encoding propagates to the caller unchanged.
    """
    # Load the file and bring it to the codec's expected format.
    waveform, source_rate = torchaudio.load(audio_path)
    batch = convert_audio(
        waveform, source_rate, tokenizer.sample_rate, tokenizer.channels
    ).unsqueeze(0)

    # Extract discrete codes from EnCodec.
    with torch.no_grad():
        return tokenizer.encode(batch)
def remove_encodec_weight_norm(model):
    """Remove weight normalization from an EnCodec model's convolutions.

    Walks the direct children of ``model.encoder.model`` and
    ``model.decoder.model`` and calls ``torch.nn.utils.remove_weight_norm``
    on the wrapped convolution of each recognised layer:

    * ``SEANetResnetBlock``: its shortcut convolution and every ``SConv1d``
      directly inside its ``block``.
    * ``SConv1d``: its inner ``conv.conv``.
    * ``SConvTranspose1d`` (decoder only): its inner ``convtr.convtr``.

    Args:
        model: Object exposing ``encoder.model`` and ``decoder.model``
            module containers (presumably an ``EncodecModel`` -- confirm
            with callers).

    Returns:
        None. The model is modified in place.
    """
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def strip_resnet_block(resnet):
        # Shortcut first, then each SConv1d that sits directly in the block.
        remove_weight_norm(resnet.shortcut.conv.conv)
        for inner in resnet.block._modules.values():
            if isinstance(inner, SConv1d):
                remove_weight_norm(inner.conv.conv)

    for layer in model.encoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            strip_resnet_block(layer)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)

    # The decoder additionally contains transposed convolutions.
    for layer in model.decoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            strip_resnet_block(layer)
        elif isinstance(layer, SConvTranspose1d):
            remove_weight_norm(layer.convtr.convtr)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)
def extract_encodec_token(wav_path):
    """Encode an audio file with a freshly created EnCodec model.

    A new 24 kHz EnCodec model is created on every call with its target
    bandwidth set to 6.0. The file is loaded, converted to the model's
    sample rate and channel count, and encoded without gradient tracking.
    When CUDA is available both the model and the waveform are moved to the
    GPU first.

    Args:
        wav_path: Path of the audio file to load with ``torchaudio.load``.

    Returns:
        A NumPy array: the first element of every encoded frame,
        concatenated along the last axis, with batch item 0 selected and
        then transposed.
    """
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)

    waveform, source_rate = torchaudio.load(wav_path)
    batch = convert_audio(
        waveform, source_rate, codec.sample_rate, codec.channels
    ).unsqueeze(0)

    if torch.cuda.is_available():
        codec = codec.cuda()
        batch = batch.cuda()

    with torch.no_grad():
        frames = codec.encode(batch)
        # Join the per-frame codes along time: [B, n_q, T]
        stacked = torch.cat([frame[0] for frame in frames], dim=-1)
        # Keep the single batch item and put time first: [T, 8]
        return stacked.cpu().numpy()[0].T
|