Spaces:
Sleeping
Sleeping
import torch | |
import time | |
import numpy as np | |
class SnacConfig: | |
audio_vocab_size = 4096 | |
padded_vocab_size = 4160 | |
end_of_audio = 4097 | |
snac_config = SnacConfig() | |
def get_time_str(): | |
time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime()) | |
return time_str | |
def layershift(input_id, layer, stride=4160, shift=152000): | |
return input_id + shift + layer * stride | |
def generate_audio_data(snac_tokens, snacmodel, device=None): | |
audio = reconstruct_tensors(snac_tokens, device) | |
with torch.inference_mode(): | |
audio_hat = snacmodel.decode(audio) | |
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0 | |
audio_data = audio_data.astype(np.int16) | |
audio_data = audio_data.tobytes() | |
return audio_data | |
def get_snac(list_output, index, nums_generate): | |
snac = [] | |
start = index | |
for i in range(nums_generate): | |
snac.append("#") | |
for j in range(7): | |
snac.append(list_output[j][start - nums_generate - 5 + j + i]) | |
return snac | |
def reconscruct_snac(output_list): | |
if len(output_list) == 8: | |
output_list = output_list[:-1] | |
output = [] | |
for i in range(7): | |
output_list[i] = output_list[i][i + 1 :] | |
for i in range(len(output_list[-1])): | |
output.append("#") | |
for j in range(7): | |
output.append(output_list[j][i]) | |
return output | |
def reconstruct_tensors(flattened_output, device=None): | |
"""Reconstructs the list of tensors from the flattened output.""" | |
if device is None: | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
def count_elements_between_hashes(lst): | |
try: | |
# Find the index of the first '#' | |
first_index = lst.index("#") | |
# Find the index of the second '#' after the first | |
second_index = lst.index("#", first_index + 1) | |
# Count the elements between the two indices | |
return second_index - first_index - 1 | |
except ValueError: | |
# Handle the case where there aren't enough '#' symbols | |
return "List does not contain two '#' symbols" | |
def remove_elements_before_hash(flattened_list): | |
try: | |
# Find the index of the first '#' | |
first_hash_index = flattened_list.index("#") | |
# Return the list starting from the first '#' | |
return flattened_list[first_hash_index:] | |
except ValueError: | |
# Handle the case where there is no '#' | |
return "List does not contain the symbol '#'" | |
def list_to_torch_tensor(tensor1): | |
# Convert the list to a torch tensor | |
tensor = torch.tensor(tensor1) | |
# Reshape the tensor to have size (1, n) | |
tensor = tensor.unsqueeze(0) | |
return tensor | |
flattened_output = remove_elements_before_hash(flattened_output) | |
codes = [] | |
tensor1 = [] | |
tensor2 = [] | |
tensor3 = [] | |
tensor4 = [] | |
n_tensors = count_elements_between_hashes(flattened_output) | |
if n_tensors == 7: | |
for i in range(0, len(flattened_output), 8): | |
tensor1.append(flattened_output[i + 1]) | |
tensor2.append(flattened_output[i + 2]) | |
tensor3.append(flattened_output[i + 3]) | |
tensor3.append(flattened_output[i + 4]) | |
tensor2.append(flattened_output[i + 5]) | |
tensor3.append(flattened_output[i + 6]) | |
tensor3.append(flattened_output[i + 7]) | |
codes = [ | |
list_to_torch_tensor(tensor1).to(device), | |
list_to_torch_tensor(tensor2).to(device), | |
list_to_torch_tensor(tensor3).to(device), | |
] | |
if n_tensors == 15: | |
for i in range(0, len(flattened_output), 16): | |
tensor1.append(flattened_output[i + 1]) | |
tensor2.append(flattened_output[i + 2]) | |
tensor3.append(flattened_output[i + 3]) | |
tensor4.append(flattened_output[i + 4]) | |
tensor4.append(flattened_output[i + 5]) | |
tensor3.append(flattened_output[i + 6]) | |
tensor4.append(flattened_output[i + 7]) | |
tensor4.append(flattened_output[i + 8]) | |
tensor2.append(flattened_output[i + 9]) | |
tensor3.append(flattened_output[i + 10]) | |
tensor4.append(flattened_output[i + 11]) | |
tensor4.append(flattened_output[i + 12]) | |
tensor3.append(flattened_output[i + 13]) | |
tensor4.append(flattened_output[i + 14]) | |
tensor4.append(flattened_output[i + 15]) | |
codes = [ | |
list_to_torch_tensor(tensor1).to(device), | |
list_to_torch_tensor(tensor2).to(device), | |
list_to_torch_tensor(tensor3).to(device), | |
list_to_torch_tensor(tensor4).to(device), | |
] | |
return codes | |