# CHiME8Challenge / validation.py
# Author: shaipeerms
# Commit fbb1d85: "Added validate_zip tests"
import json
import os
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Union
from zipfile import BadZipFile, ZipFile
def validate_zip(submission_track: str, submission_zip: Union[Path, str]):
"""
Validates the submission format and contents
Args:
submission_track: the track of the submission
submission_zip: path to the submission zip file
Raises:
ValueError: if the submission zip is invalid
"""
with TemporaryDirectory() as temp_dir:
with ZipFile(submission_zip, 'r') as submission_zip_file:
submission_zip_file.extractall(temp_dir)
submission_dir = Path(temp_dir)
if submission_track in ['NOTSOFAR-SC', 'NOTSOFAR-MC']:
validate_notsofar_submission(submission_dir=submission_dir)
elif submission_track in ['DASR-Constrained-LM', 'DASR-Unconstrained-LM']:
validate_dasr_submission(submission_dir=submission_dir)
else:
raise ValueError(f'Invalid submission track: {submission_track}')
def validate_notsofar_submission(submission_dir: Path):
    """
    Validates NOTSOFAR submission format and contents

    `tcp_wer_hyp.json` is required; `tc_orc_wer_ref.json` is optional and is only
    validated when present. Each validated file must be a JSON list whose entries
    contain all of the required fields.

    Args:
        submission_dir: path to the (extracted) submission directory
    Raises:
        ValueError: if a required file is missing, a submission path exists but is
            not a regular file, or a file's JSON structure is invalid
    """
    submission_file_names = ['tcp_wer_hyp.json']
    optional_file_names = ['tc_orc_wer_ref.json']
    fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
    for file_name in submission_file_names + optional_file_names:
        file_path = submission_dir / file_name
        if not file_path.exists():
            if file_name in submission_file_names:
                raise ValueError(f'Missing {file_name}')
            continue  # optional file was not submitted
        if not file_path.is_file():
            # e.g. a directory with this name inside the zip; opening it would raise
            # IsADirectoryError rather than the documented ValueError.
            raise ValueError(f'Invalid {file_name}, expecting a file')
        validate_json_file_structure(file_path, fields)
def validate_dasr_submission(submission_dir: Path):
    """
    Validates DASR submission format and contents

    The submission must contain a `dev` directory holding one JSON file per dataset
    (`chime6.json`, `dipco.json`, `mixer6.json`, `notsofar1.json`). Each file must be
    a JSON list whose entries contain all of the required fields.

    Args:
        submission_dir: path to the (extracted) submission directory
    Raises:
        ValueError: if `dev` is not a directory, a dataset file is missing or is not a
            regular file, or a file's JSON structure is invalid
    """
    submission_file_names = ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']
    fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
    dev_dir = submission_dir / 'dev'
    # is_dir(), not exists(): a plain file named `dev` must not pass this check and
    # then be misreported as "Missing chime6.json".
    if not dev_dir.is_dir():
        raise ValueError('Missing `dev` directory, expecting a directory named `dev` with the submission files in it.')
    for file_name in submission_file_names:
        file_path = dev_dir / file_name
        if not file_path.exists():
            raise ValueError(f'Missing {file_name}')
        if not file_path.is_file():
            # Opening a directory would raise IsADirectoryError, not ValueError.
            raise ValueError(f'Invalid {file_name}, expecting a file')
        validate_json_file_structure(file_path, fields)
def validate_json_file_structure(file_path: Path, fields: List[str]):
    """
    Validates the structure of a json file

    The file must contain a JSON list in which every entry is a JSON object whose
    keys include all of `fields`. Field values are not checked.

    Args:
        file_path: path to the json file
        fields: list of fields that are required in each entry
    Raises:
        ValueError: if the file is not valid UTF-8 JSON (json.JSONDecodeError and
            UnicodeDecodeError are ValueError subclasses), the top level is not a
            list, or an entry is not an object or lacks a required field
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    # 'utf-8-sig' also strips a leading BOM, which json.load would otherwise reject.
    with open(file_path, 'r', encoding='utf-8-sig') as json_file:
        json_data: Any = json.load(json_file)
    if not isinstance(json_data, list):
        raise ValueError(f'Invalid `{file_path.name}` format, expecting a list of entries')
    for index, data in enumerate(json_data):
        # `field in data` on a string does substring matching and on a number raises
        # TypeError, so each entry must explicitly be an object (dict).
        if not isinstance(data, dict):
            raise ValueError(f'Invalid `{file_path.name}` format, entry {index} is not an object')
        missing = [field for field in fields if field not in data]
        if missing:
            raise ValueError(f'Invalid `{file_path.name}` format, fields: {fields} are required in each entry '
                             f'(entry {index} is missing {missing})')
####################################################################################################
# Tests
####################################################################################################
class TestValidateZip(unittest.TestCase):
    """End-to-end tests for `validate_zip`: each test builds a submission zip and validates it."""
    DATA_SAMPLES = 10
    # The four per-dataset files every DASR submission places under `dev/`.
    DASR_FILE_NAMES = ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json']

    @classmethod
    def setUpClass(cls):
        # Entries carrying every field the validators require.
        cls.valid_data = [{'session_id': 'session_id', 'words': 'words', 'speaker': 'speaker',
                           'start_time': 0.0, 'end_time': 1.0} for _ in range(cls.DATA_SAMPLES)]
        # Entries lacking `speaker` and `end_time`, so the validators must reject them.
        cls.invalid_data = [{'session_id': 'session_id', 'words': 'words',
                             'start_time': 0.0} for _ in range(cls.DATA_SAMPLES)]

    def setUp(self):
        self.temp_dir = TemporaryDirectory()
        self.submission_zip = Path(self.temp_dir.name) / 'submission.zip'

    def tearDown(self):
        self.temp_dir.cleanup()

    def create_test_data(self, submission_track: str, data: List[Dict[str, Any]], json_file_names: List[str],
                         parent_zip_dir: Optional[str] = None):
        """
        Writes a submission zip containing `data` serialized as JSON under each file name.

        Args:
            submission_track: track name, returned unchanged
            data: entries to serialize into every JSON file
            json_file_names: names of the JSON files to add to the zip
            parent_zip_dir: optional directory inside the zip to hold the files (e.g. 'dev')
        Returns:
            (submission_track, submission_zip_path), ready to be unpacked into `validate_zip`
        """
        payload = json.dumps(data)
        with ZipFile(self.submission_zip, 'w') as submission_zip_file:
            for json_file_name in json_file_names:
                # Zip member names always use '/' as separator, so build them directly
                # instead of going through a platform-dependent Path.
                arcname = f'{parent_zip_dir}/{json_file_name}' if parent_zip_dir else json_file_name
                submission_zip_file.writestr(arcname, payload)
        return submission_track, self.submission_zip

    def test_invalid_track(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'UNKNOWN-TRACK', self.valid_data, ['tcp_wer_hyp.json']))

    def test_NOTSOFAR_SC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_SC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-SC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_SC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_SC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-SC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_NOTSOFAR_MC_valid_data_tcp(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json'])))

    def test_NOTSOFAR_MC_valid_data_tcp_and_tcorc(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'NOTSOFAR-MC', self.valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'])))

    def test_NOTSOFAR_MC_missing_tcp_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.valid_data, ['tc_orc_wer_ref.json']))

    def test_NOTSOFAR_MC_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'NOTSOFAR-MC', self.invalid_data, ['tcp_wer_hyp.json']))

    def test_DASR_Constrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Constrained-LM', self.valid_data, self.DASR_FILE_NAMES, 'dev')))

    def test_DASR_Constrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.invalid_data, self.DASR_FILE_NAMES, 'dev'))

    def test_DASR_Constrained_LM_missing_dev_dir(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILE_NAMES))

    def test_DASR_Constrained_LM_missing_json_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Constrained-LM', self.valid_data, self.DASR_FILE_NAMES[:-1], 'dev'))

    def test_DASR_Unconstrained_LM_valid_data(self):
        self.assertIsNone(validate_zip(*self.create_test_data(
            'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILE_NAMES, 'dev')))

    def test_DASR_Unconstrained_LM_invalid_data(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.invalid_data, self.DASR_FILE_NAMES, 'dev'))

    def test_DASR_Unconstrained_LM_missing_dev_dir(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILE_NAMES))

    def test_DASR_Unconstrained_LM_missing_json_file(self):
        with self.assertRaises(ValueError):
            validate_zip(*self.create_test_data(
                'DASR-Unconstrained-LM', self.valid_data, self.DASR_FILE_NAMES[:-1], 'dev'))
if __name__ == '__main__':
    # Executed as a script: discover and run the unittest cases defined in this module.
    unittest.main(module='__main__')