gpt-omni committed on
Commit
e6f5492
1 Parent(s): b148265
Files changed (1)
  1. snac_utils.py +143 -0
snac_utils.py ADDED
@@ -0,0 +1,143 @@
import torch
import time
import numpy as np


class SnacConfig:
    audio_vocab_size = 4096
    padded_vocab_size = 4160
    end_of_audio = 4097


snac_config = SnacConfig()


def get_time_str():
    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    return time_str


def layershift(input_id, layer, stride=4160, shift=152000):
    # Shift an audio token id into the vocabulary region reserved for `layer`.
    return input_id + shift + layer * stride


def generate_audio_data(snac_tokens, snacmodel):
    # Decode flattened SNAC tokens into 16-bit PCM bytes.
    audio = reconstruct_tensors(snac_tokens)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
    audio_data = audio_data.astype(np.int16)
    audio_data = audio_data.tobytes()
    return audio_data


def get_snac(list_output, index, nums_generate):
    # Collect the next `nums_generate` frames of 7 SNAC codes each,
    # prefixing every frame with a "#" separator.
    snac = []
    start = index
    for i in range(nums_generate):
        snac.append("#")
        for j in range(7):
            snac.append(list_output[j][start - nums_generate - 5 + j + i])
    return snac


def reconscruct_snac(output_list):
    # Undo the per-layer delay pattern and flatten the 7 code streams into a
    # single "#"-separated sequence of frames.
    if len(output_list) == 8:
        output_list = output_list[:-1]
    output = []
    for i in range(7):
        output_list[i] = output_list[i][i + 1 :]
    for i in range(len(output_list[-1])):
        output.append("#")
        for j in range(7):
            output.append(output_list[j][i])
    return output


def reconstruct_tensors(flattened_output):
    """Reconstructs the list of tensors from the flattened output."""

    def count_elements_between_hashes(lst):
        try:
            # Find the index of the first '#'
            first_index = lst.index("#")
            # Find the index of the second '#' after the first
            second_index = lst.index("#", first_index + 1)
            # Count the elements between the two indices
            return second_index - first_index - 1
        except ValueError:
            # Handle the case where there aren't enough '#' symbols
            return "List does not contain two '#' symbols"

    def remove_elements_before_hash(flattened_list):
        try:
            # Find the index of the first '#'
            first_hash_index = flattened_list.index("#")
            # Return the list starting from the first '#'
            return flattened_list[first_hash_index:]
        except ValueError:
            # Handle the case where there is no '#'
            return "List does not contain the symbol '#'"

    def list_to_torch_tensor(tensor1):
        # Convert the list to a torch tensor
        tensor = torch.tensor(tensor1)
        # Reshape the tensor to have size (1, n)
        tensor = tensor.unsqueeze(0)
        return tensor

    flattened_output = remove_elements_before_hash(flattened_output)
    codes = []
    tensor1 = []
    tensor2 = []
    tensor3 = []
    tensor4 = []

    n_tensors = count_elements_between_hashes(flattened_output)
    if n_tensors == 7:
        # 3-codebook layout: 1 coarse, 2 medium, 4 fine codes per frame.
        for i in range(0, len(flattened_output), 8):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor3.append(flattened_output[i + 4])

            tensor2.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor3.append(flattened_output[i + 7])
        codes = [
            list_to_torch_tensor(tensor1).cuda(),
            list_to_torch_tensor(tensor2).cuda(),
            list_to_torch_tensor(tensor3).cuda(),
        ]

    if n_tensors == 15:
        # 4-codebook layout: 1 coarse, 2, 4, and 8 codes per frame.
        for i in range(0, len(flattened_output), 16):
            tensor1.append(flattened_output[i + 1])
            tensor2.append(flattened_output[i + 2])
            tensor3.append(flattened_output[i + 3])
            tensor4.append(flattened_output[i + 4])
            tensor4.append(flattened_output[i + 5])
            tensor3.append(flattened_output[i + 6])
            tensor4.append(flattened_output[i + 7])
            tensor4.append(flattened_output[i + 8])

            tensor2.append(flattened_output[i + 9])
            tensor3.append(flattened_output[i + 10])
            tensor4.append(flattened_output[i + 11])
            tensor4.append(flattened_output[i + 12])
            tensor3.append(flattened_output[i + 13])
            tensor4.append(flattened_output[i + 14])
            tensor4.append(flattened_output[i + 15])

        codes = [
            list_to_torch_tensor(tensor1).cuda(),
            list_to_torch_tensor(tensor2).cuda(),
            list_to_torch_tensor(tensor3).cuda(),
            list_to_torch_tensor(tensor4).cuda(),
        ]

    return codes
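For context, a minimal usage sketch follows (not part of the commit), showing how these helpers chain together to turn layer-wise SNAC tokens back into playable audio. It assumes the snac package's SNAC.from_pretrained loader with the hubertsiuzdak/snac_24khz checkpoint that mini-omni targets; the layer_outputs lists below are hypothetical stand-ins for real model generations.

import wave

import numpy as np
from snac import SNAC  # assumed dependency (pip install snac)

from snac_utils import generate_audio_data, get_time_str, reconscruct_snac

# Load the 24 kHz SNAC decoder on the GPU (checkpoint name is an assumption
# based on the model mini-omni targets).
snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().cuda()

# Hypothetical per-layer outputs standing in for real model generations:
# 8 streams (7 SNAC layers plus a text stream), where stream i carries i + 1
# leading pad tokens followed by one code per frame.
n_frames = 50
layer_outputs = [
    [0] * (i + 1) + [int(c) for c in np.random.randint(0, 4096, n_frames)]
    for i in range(8)
]

# Flatten the delayed streams into "#"-separated frames, then decode to PCM.
snac_tokens = reconscruct_snac(layer_outputs)
pcm_bytes = generate_audio_data(snac_tokens, snacmodel)

# Write the 16-bit mono PCM bytes to a WAV file for playback.
with wave.open(f"out_{get_time_str()}.wav", "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)
    wav.setframerate(24000)
    wav.writeframes(pcm_bytes)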