|
import base64 |
|
import io |
|
from unittest import TestCase |
|
|
|
import numpy as np |
|
import numpy.testing |
|
import soundfile |
|
from scipy.signal import resample |
|
|
|
from voicevox_engine.utility import ConnectBase64WavesException, connect_base64_waves |
|
|
|
|
|
def generate_sine_wave_ndarray(
    seconds: float, samplerate: int, frequency: float
) -> np.ndarray:
    """Build a single-channel sine wave as a float32 array.

    Args:
        seconds: Duration of the signal, in seconds.
        samplerate: Number of samples per second.
        frequency: Frequency of the sine, in cycles per second.

    Returns:
        A 1-D ``np.float32`` array with ``int(seconds * samplerate)`` samples.
    """
    sample_count = int(seconds * samplerate)
    # endpoint=False leaves out t == seconds, so samples cover [0, seconds).
    timeline = np.linspace(0, seconds, sample_count, endpoint=False)
    phase = 2 * np.pi * frequency * timeline
    return np.sin(phase).astype(np.float32)
|
|
|
|
|
def encode_bytes(wave_ndarray: np.ndarray, samplerate: int) -> bytes:
    """Serialize a waveform to the bytes of an in-memory WAV file.

    Args:
        wave_ndarray: Sample data handed to ``soundfile.write``.
        samplerate: Sampling rate recorded in the WAV file.

    Returns:
        The complete WAV file content (FLOAT subtype) as ``bytes``.
    """
    with io.BytesIO() as buffer:
        soundfile.write(
            buffer,
            wave_ndarray,
            samplerate,
            subtype="FLOAT",
            format="WAV",
        )
        # getvalue() returns the whole buffer regardless of the current
        # stream position, so no rewind is needed before reading it out.
        return buffer.getvalue()
|
|
|
|
|
def generate_sine_wave_bytes(
    seconds: float, samplerate: int, frequency: float
) -> bytes:
    """Generate a sine wave and return it encoded as WAV file bytes.

    Args:
        seconds: Duration of the signal, in seconds.
        samplerate: Number of samples per second; also written to the WAV file.
        frequency: Frequency of the sine, in cycles per second.

    Returns:
        The WAV file content produced by ``encode_bytes``.
    """
    return encode_bytes(
        generate_sine_wave_ndarray(
            seconds=seconds, samplerate=samplerate, frequency=frequency
        ),
        samplerate=samplerate,
    )
|
|
|
|
|
def encode_base64(wave_bytes: bytes) -> str:
    """Return the standard-alphabet Base64 text for *wave_bytes*."""
    encoded = base64.b64encode(wave_bytes)
    return str(encoded, "utf-8")
|
|
|
|
|
def generate_sine_wave_base64(seconds: float, samplerate: int, frequency: float) -> str:
    """Generate a sine wave and return its WAV file as a Base64 string.

    Args:
        seconds: Duration of the signal, in seconds.
        samplerate: Number of samples per second.
        frequency: Frequency of the sine, in cycles per second.

    Returns:
        Base64 text of the WAV bytes from ``generate_sine_wave_bytes``.
    """
    return encode_base64(generate_sine_wave_bytes(seconds, samplerate, frequency))
|
|
|
|
|
class TestConnectBase64Waves(TestCase):
    """Tests for ``connect_base64_waves`` driven by synthetic sine-wave WAVs."""

    def test_connect(self):
        """Joining a wave with itself must equal the plain concatenation."""
        samplerate = 1000
        wave = generate_sine_wave_ndarray(
            seconds=2, samplerate=samplerate, frequency=10
        )
        wave_base64 = encode_base64(encode_bytes(wave, samplerate=samplerate))

        wave_x2_ref = np.concatenate([wave, wave])

        wave_x2, _ = connect_base64_waves(waves=[wave_base64, wave_base64])

        self.assertEqual(wave_x2_ref.shape, wave_x2.shape)
        # assert_array_equal reports the mismatching elements on failure,
        # unlike assertTrue((a == b).all()) which only says "False is not true".
        numpy.testing.assert_array_equal(wave_x2_ref, wave_x2)

    def test_no_wave_error(self):
        """An empty list of waves must raise ConnectBase64WavesException."""
        with self.assertRaises(ConnectBase64WavesException):
            connect_base64_waves(waves=[])

    def test_invalid_base64_error(self):
        """A corrupted Base64 string must raise ConnectBase64WavesException."""
        wave_1000hz = generate_sine_wave_base64(
            seconds=2, samplerate=1000, frequency=10
        )
        # Dropping the first character corrupts the Base64 text itself.
        wave_1000hz_broken = wave_1000hz[1:]

        with self.assertRaises(ConnectBase64WavesException):
            connect_base64_waves(waves=[wave_1000hz_broken])

    def test_invalid_wave_file_error(self):
        """Valid Base64 wrapping a corrupted WAV file must raise as well."""
        wave_1000hz = generate_sine_wave_bytes(seconds=2, samplerate=1000, frequency=10)
        # Dropping the first byte corrupts the WAV data; the Base64 layer
        # applied afterwards is still well-formed.
        wave_1000hz_broken_bytes = wave_1000hz[1:]
        wave_1000hz_broken = encode_base64(wave_1000hz_broken_bytes)

        with self.assertRaises(ConnectBase64WavesException):
            connect_base64_waves(waves=[wave_1000hz_broken])

    def test_different_frequency(self):
        """Waves with different sampling rates are compared at 24000 Hz."""
        high_rate = 24000
        low_rate = 1000

        wave_24000hz = generate_sine_wave_ndarray(
            seconds=1, samplerate=high_rate, frequency=10
        )
        wave_1000hz = generate_sine_wave_ndarray(
            seconds=2, samplerate=low_rate, frequency=10
        )
        wave_24000_base64 = encode_base64(
            encode_bytes(wave_24000hz, samplerate=high_rate)
        )
        wave_1000_base64 = encode_base64(encode_bytes(wave_1000hz, samplerate=low_rate))

        # Reference: the 1000 Hz wave resampled to 24000 Hz, appended to the
        # wave that is already at 24000 Hz.
        wave_1000hz_to_24000hz = resample(
            wave_1000hz, high_rate * len(wave_1000hz) // low_rate
        )
        wave_x2_ref = np.concatenate([wave_24000hz, wave_1000hz_to_24000hz])

        wave_x2, _ = connect_base64_waves(waves=[wave_24000_base64, wave_1000_base64])

        self.assertEqual(wave_x2_ref.shape, wave_x2.shape)
        # Resampling is a floating-point computation, so compare approximately.
        numpy.testing.assert_array_almost_equal(wave_x2_ref, wave_x2)

    def test_different_channels(self):
        """A mono wave joined with a stereo wave must yield stereo output."""
        wave_1000hz = generate_sine_wave_ndarray(
            seconds=2, samplerate=1000, frequency=10
        )
        # Shape (samples, 2): both channels carry the same signal.
        wave_2ch_1000hz = np.array([wave_1000hz, wave_1000hz]).T
        wave_1ch_base64 = encode_base64(encode_bytes(wave_1000hz, samplerate=1000))
        wave_2ch_base64 = encode_base64(encode_bytes(wave_2ch_1000hz, samplerate=1000))

        # Reference: the mono input is expected to appear on both channels,
        # which makes it identical to the stereo wave.
        wave_x2_ref = np.concatenate([wave_2ch_1000hz, wave_2ch_1000hz])

        wave_x2, _ = connect_base64_waves(waves=[wave_1ch_base64, wave_2ch_base64])

        self.assertEqual(wave_x2_ref.shape, wave_x2.shape)
        numpy.testing.assert_array_equal(wave_x2_ref, wave_x2)
|
|