Spaces:
Build error
Build error
import base64 | |
import io | |
from unittest import TestCase | |
import numpy as np | |
import numpy.testing | |
import soundfile | |
from scipy.signal import resample | |
from voicevox_engine.utility import ConnectBase64WavesException, connect_base64_waves | |
def generate_sine_wave_ndarray( | |
seconds: float, samplerate: int, frequency: float | |
) -> np.ndarray: | |
x = np.linspace(0, seconds, int(seconds * samplerate), endpoint=False) | |
wave = np.sin(2 * np.pi * frequency * x).astype(np.float32) | |
return wave | |
def encode_bytes(wave_ndarray: np.ndarray, samplerate: int) -> bytes: | |
wave_bio = io.BytesIO() | |
soundfile.write( | |
file=wave_bio, | |
data=wave_ndarray, | |
samplerate=samplerate, | |
format="WAV", | |
subtype="FLOAT", | |
) | |
wave_bio.seek(0) | |
return wave_bio.getvalue() | |
def generate_sine_wave_bytes( | |
seconds: float, samplerate: int, frequency: float | |
) -> bytes: | |
wave_ndarray = generate_sine_wave_ndarray(seconds, samplerate, frequency) | |
return encode_bytes(wave_ndarray, samplerate) | |
def encode_base64(wave_bytes: bytes) -> str: | |
return base64.standard_b64encode(wave_bytes).decode("utf-8") | |
def generate_sine_wave_base64(seconds: float, samplerate: int, frequency: float) -> str: | |
wave_bytes = generate_sine_wave_bytes(seconds, samplerate, frequency) | |
wave_base64 = encode_base64(wave_bytes) | |
return wave_base64 | |
class TestConnectBase64Waves(TestCase): | |
def test_connect(self): | |
samplerate = 1000 | |
wave = generate_sine_wave_ndarray( | |
seconds=2, samplerate=samplerate, frequency=10 | |
) | |
wave_base64 = encode_base64(encode_bytes(wave, samplerate=samplerate)) | |
wave_x2_ref = np.concatenate([wave, wave]) | |
wave_x2, _ = connect_base64_waves(waves=[wave_base64, wave_base64]) | |
self.assertEqual(wave_x2_ref.shape, wave_x2.shape) | |
self.assertTrue((wave_x2_ref == wave_x2).all()) | |
def test_no_wave_error(self): | |
self.assertRaises(ConnectBase64WavesException, connect_base64_waves, waves=[]) | |
def test_invalid_base64_error(self): | |
wave_1000hz = generate_sine_wave_base64( | |
seconds=2, samplerate=1000, frequency=10 | |
) | |
wave_1000hz_broken = wave_1000hz[1:] # remove head 1 char | |
self.assertRaises( | |
ConnectBase64WavesException, | |
connect_base64_waves, | |
waves=[ | |
wave_1000hz_broken, | |
], | |
) | |
def test_invalid_wave_file_error(self): | |
wave_1000hz = generate_sine_wave_bytes(seconds=2, samplerate=1000, frequency=10) | |
wave_1000hz_broken_bytes = wave_1000hz[1:] # remove head 1 byte | |
wave_1000hz_broken = encode_base64(wave_1000hz_broken_bytes) | |
self.assertRaises( | |
ConnectBase64WavesException, | |
connect_base64_waves, | |
waves=[ | |
wave_1000hz_broken, | |
], | |
) | |
def test_different_frequency(self): | |
wave_24000hz = generate_sine_wave_ndarray( | |
seconds=1, samplerate=24000, frequency=10 | |
) | |
wave_1000hz = generate_sine_wave_ndarray( | |
seconds=2, samplerate=1000, frequency=10 | |
) | |
wave_24000_base64 = encode_base64(encode_bytes(wave_24000hz, samplerate=24000)) | |
wave_1000_base64 = encode_base64(encode_bytes(wave_1000hz, samplerate=1000)) | |
wave_1000hz_to2400hz = resample(wave_1000hz, 24000 * len(wave_1000hz) // 1000) | |
wave_x2_ref = np.concatenate([wave_24000hz, wave_1000hz_to2400hz]) | |
wave_x2, _ = connect_base64_waves(waves=[wave_24000_base64, wave_1000_base64]) | |
self.assertEqual(wave_x2_ref.shape, wave_x2.shape) | |
numpy.testing.assert_array_almost_equal(wave_x2_ref, wave_x2) | |
def test_different_channels(self): | |
wave_1000hz = generate_sine_wave_ndarray( | |
seconds=2, samplerate=1000, frequency=10 | |
) | |
wave_2ch_1000hz = np.array([wave_1000hz, wave_1000hz]).T | |
wave_1ch_base64 = encode_base64(encode_bytes(wave_1000hz, samplerate=1000)) | |
wave_2ch_base64 = encode_base64(encode_bytes(wave_2ch_1000hz, samplerate=1000)) | |
wave_x2_ref = np.concatenate([wave_2ch_1000hz, wave_2ch_1000hz]) | |
wave_x2, _ = connect_base64_waves(waves=[wave_1ch_base64, wave_2ch_base64]) | |
self.assertEqual(wave_x2_ref.shape, wave_x2.shape) | |
self.assertTrue((wave_x2_ref == wave_x2).all()) | |