freddyaboulton HF staff committed on
Commit
9d62c72
1 Parent(s): 2776201

snac_utils

Browse files
Files changed (1) hide show
  1. utils/snac_utils.py +146 -0
utils/snac_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import numpy as np
4
+
5
+
6
class SnacConfig:
    """Constants describing the SNAC audio token vocabulary."""

    # Number of audio token ids.
    audio_vocab_size = 4096
    # Vocabulary size after padding; equals the default ``stride`` used by
    # ``layershift`` in this file.
    padded_vocab_size = 4160
    # Token id used as the end-of-audio marker.
    end_of_audio = 4097


# Shared module-level configuration instance.
snac_config = SnacConfig()
13
+
14
+
15
def get_time_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = time.localtime()
    return time.strftime("%Y%m%d_%H%M%S", now)
18
+
19
+
20
def layershift(input_id, layer, stride=4160, shift=152000):
    """Offset a token id into the id range reserved for ``layer``.

    Args:
        input_id: Token id (or any value supporting ``+``) to offset.
        layer: Index of the layer; each layer occupies ``stride`` ids.
        stride: Number of ids reserved per layer.
        shift: Base offset added before the per-layer offset
            (presumably the size of the preceding vocabulary -- confirm
            against the caller).

    Returns:
        ``input_id + shift + layer * stride``.
    """
    shifted = input_id + shift
    layer_offset = layer * stride
    return shifted + layer_offset
22
+
23
+
24
def generate_audio_data(snac_tokens, snacmodel, device=None):
    """Decode a flattened SNAC token stream into raw 16-bit PCM bytes.

    Args:
        snac_tokens: Flattened token list with ``"#"`` frame markers, as
            accepted by ``reconstruct_tensors``.
        snacmodel: Object exposing ``decode(codes)`` that returns a tensor
            of audio samples (presumably floats in roughly ``[-1, 1]`` --
            confirm against the model in use).
        device: Device the code tensors are moved to; forwarded to
            ``reconstruct_tensors``.

    Returns:
        The decoded samples as ``int16`` values serialized with
        ``ndarray.tobytes()``.
    """
    codes = reconstruct_tensors(snac_tokens, device)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(codes)

    samples = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
    # Clip before the integer cast: a sample at or above +1.0 scales to
    # 32768, which is outside the int16 range and would otherwise wrap
    # around to a large negative value (audible clicks).
    samples = np.clip(samples, -32768.0, 32767.0)
    return samples.astype(np.int16).tobytes()
32
+
33
+
34
def get_snac(list_output, index, nums_generate):
    """Collect the last ``nums_generate`` frames from seven token layers.

    Each frame is emitted as a ``"#"`` marker followed by one token from
    each of the seven layers. Layer ``j`` is read ``j`` positions later
    than layer 0, so the layers are sampled along a diagonal.

    Args:
        list_output: Indexable of at least seven indexable token sequences.
        index: Reference position; reading starts at
            ``index - nums_generate - 5`` for layer 0.
        nums_generate: Number of frames to collect.

    Returns:
        A flat list of ``"#"`` markers and tokens, eight entries per frame.
    """
    base = index - nums_generate - 5
    frames = []
    for step in range(nums_generate):
        frames.append("#")
        frames.extend(
            list_output[layer][base + layer + step] for layer in range(7)
        )
    return frames
43
+
44
+
45
def reconscruct_snac(output_list):
    """Flatten seven staggered token layers into a ``"#"``-delimited stream.

    If ``output_list`` has exactly eight entries the last one is ignored.
    Layer ``i`` then has its first ``i + 1`` entries dropped, and the
    result is emitted frame by frame: a ``"#"`` marker followed by one
    token from each of the seven layers. The number of frames is the
    length of the last remaining layer after trimming.

    The input list is not modified; trimming is done on a shallow copy so
    the result does not depend on whether seven or eight layers were
    passed.

    Args:
        output_list: List of at least seven sliceable token sequences.

    Returns:
        A flat list of ``"#"`` markers and tokens, eight entries per frame.

    Raises:
        IndexError: If fewer than seven layers are supplied, or a layer is
            too short for the number of frames.
    """
    if len(output_list) == 8:
        layers = list(output_list[:-1])
    else:
        # Copy so the caller's list is left untouched by the trimming below.
        layers = list(output_list)

    for i in range(7):
        layers[i] = layers[i][i + 1:]

    output = []
    for frame in range(len(layers[-1])):
        output.append("#")
        output.extend(layers[j][frame] for j in range(7))
    return output
56
+
57
+
58
def reconstruct_tensors(flattened_output, device=None):
    """Rebuild the per-layer code tensors from a flattened token stream.

    The stream is a sequence of frames, each introduced by a ``"#"``
    marker and followed by either 7 or 15 tokens. Anything before the
    first marker is discarded. Tokens within a frame are distributed over
    three (7-token frames) or four (15-token frames) layers according to
    a fixed interleaving pattern.

    The frame width is taken from the distance between the first two
    markers; a stream holding a single frame uses the distance from the
    marker to the end of the stream.

    Args:
        flattened_output: List of tokens and ``"#"`` markers.
        device: Device the tensors are moved to. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        A list of tensors of shape ``(1, n)``, one per layer, or an empty
        list when the stream has no marker or its frame width is neither
        7 nor 15.

    Raises:
        IndexError: If a frame is shorter than the detected frame width.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Target layer for each token position inside a frame, keyed by the
    # number of tokens per frame.
    layouts = {
        7: (0, 1, 2, 2, 1, 2, 2),
        15: (0, 1, 2, 3, 3, 2, 3, 3, 1, 2, 3, 3, 2, 3, 3),
    }

    try:
        first_marker = flattened_output.index("#")
    except ValueError:
        # No frame marker at all: nothing to reconstruct.
        return []
    tokens = flattened_output[first_marker:]

    try:
        frame_width = tokens.index("#", 1) - 1
    except ValueError:
        # Only one marker: the single frame runs to the end of the stream.
        frame_width = len(tokens) - 1

    layout = layouts.get(frame_width)
    if layout is None:
        return []

    layers = [[] for _ in range(max(layout) + 1)]
    frame_stride = frame_width + 1  # tokens plus the leading marker
    for frame_start in range(0, len(tokens), frame_stride):
        for offset, layer in enumerate(layout, start=1):
            layers[layer].append(tokens[frame_start + offset])

    return [torch.tensor(layer).unsqueeze(0).to(device) for layer in layers]
146
+