shaipeerms committed
Commit bdf49c6
1 Parent(s): 1c7c790

Added validate_zip tests

Files changed (1)
  1. validation.py +85 -20
validation.py CHANGED
@@ -1,11 +1,11 @@
 import json
 from pathlib import Path
 from zipfile import ZipFile
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Union
 from tempfile import TemporaryDirectory
 
 
-def validate_zip(submission_track: str, submission_zip: str):
+def validate_zip(submission_track: str, submission_zip: Union[Path, str]):
     """
     Validates the submission format and contents
     Args:
@@ -35,20 +35,19 @@ def validate_notsofar_submission(submission_dir: Path):
     Raises:
         ValueError: if the submission zip is invalid
     """
-    submission_file_names = ['tc_orc_wer_hyp.json', 'tcp_wer_hyp.json']
+    submission_file_names = ['tcp_wer_hyp.json']
+    optional_file_names = ['tc_orc_wer_ref.json']
     fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
 
-    for file_name in submission_file_names:
+    for file_name in submission_file_names + optional_file_names:
         file_path = submission_dir / file_name
         if not file_path.exists():
-            raise ValueError(f'Missing {file_name}')
-        with open(file_path, 'r') as json_file:
-            json_data: List[Dict[str, Any]] = json.load(json_file)
-        if not isinstance(json_data, list):
-            raise ValueError(f'Invalid `{file_name}` format, expecting a list of entries')
-        for data in json_data:
-            if not all(field in data for field in fields):
-                raise ValueError(f'Invalid `{file_name}` format, fields: {fields} are required in each entry')
+            if file_name in submission_file_names:
+                raise ValueError(f'Missing {file_name}')
+            else:
+                continue
+
+        validate_json_file_structure(file_path, fields)
 
 
 def validate_dasr_submission(submission_dir: Path):
@@ -64,16 +63,82 @@ def validate_dasr_submission(submission_dir: Path):
     fields = ['session_id', 'words', 'speaker', 'start_time', 'end_time']
 
     if not (submission_dir / 'dev').exists():
-        raise ValueError('Missing dev directory, expecting a directory named `dev` with the submission files in it.')
+        raise ValueError('Missing `dev` directory, expecting a directory named `dev` with the submission files in it.')
 
     for file_name in submission_file_names:
         file_path = submission_dir / 'dev' / file_name
         if not file_path.exists():
             raise ValueError(f'Missing {file_name}')
-        with open(file_path, 'r') as json_file:
-            json_data: List[Dict[str, Any]] = json.load(json_file)
-        if not isinstance(json_data, list):
-            raise ValueError(f'Invalid `{file_name}` format, expecting a list of entries')
-        for data in json_data:
-            if not all(field in data for field in fields):
-                raise ValueError(f'Invalid `{file_name}` format, fields: {fields} are required in each entry')
+
+        validate_json_file_structure(file_path, fields)
+
+
+def validate_json_file_structure(file_path: Path, fields: List[str]):
+    """
+    Validates the structure of a json file
+    Args:
+        file_path: path to the json file
+        fields: list of fields that are required in each entry
+    Raises:
+        ValueError: if the json file is invalid
+
+    """
+    with open(file_path, 'r') as json_file:
+        json_data: List[Dict[str, Any]] = json.load(json_file)
+    if not isinstance(json_data, list):
+        raise ValueError(f'Invalid `{file_path.name}` format, expecting a list of entries')
+    for data in json_data:
+        if not all(field in data for field in fields):
+            raise ValueError(f'Invalid `{file_path.name}` format, fields: {fields} are required in each entry')
+
+
+def test_validate_zip(data_samples: int = 10):
+    import os
+    with TemporaryDirectory() as temp_dir:
+        submission_zip = Path(temp_dir) / 'submission.zip'
+        valid_data = [{'session_id': 'session_id', 'words': 'words', 'speaker': 'speaker',
+                       'start_time': 0.0, 'end_time': 1.0} for _ in range(data_samples)]
+        invalid_data = [{'session_id': 'session_id', 'words': 'words', 'start_time': 0.0} for _ in range(data_samples)]
+
+        def create_test_data(submission_track: str, data: List[Dict[str, Any]], json_file_names: List[str],
+                             parent_zip_dir: str = None):
+            submission_dir = Path(temp_dir) / submission_track
+            os.makedirs(submission_dir, exist_ok=True)
+            with ZipFile(submission_zip, 'w') as submission_zip_file:
+                for json_file_name in json_file_names:
+                    if parent_zip_dir:
+                        json_file_name = str(Path(parent_zip_dir) / json_file_name)
+                    submission_zip_file.writestr(json_file_name, json.dumps(data))
+            return submission_track, submission_zip
+
+        def test(track: str, data: List[Dict[str, Any]], json_file_names: List[str], expected_error: bool,
+                 parent_zip_dir=None):
+            try:
+                validate_zip(*create_test_data(track, data, json_file_names, parent_zip_dir))
+                assert not expected_error, f'Expected error for {track}'
+            except ValueError as e:
+                assert expected_error, f'Unexpected error for {track}'
+
+        # NOTSOFAR-SC
+        test('NOTSOFAR-SC', valid_data, ['tcp_wer_hyp.json'], False)
+        test('NOTSOFAR-SC', valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'], False)
+        test('NOTSOFAR-SC', invalid_data, ['tcp_wer_hyp.json'], True)
+        test('NOTSOFAR-SC', invalid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'], True)
+
+        # NOTSOFAR-MC
+        test('NOTSOFAR-MC', valid_data, ['tcp_wer_hyp.json'], False)
+        test('NOTSOFAR-MC', valid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'], False)
+        test('NOTSOFAR-MC', invalid_data, ['tcp_wer_hyp.json'], True)
+        test('NOTSOFAR-MC', invalid_data, ['tcp_wer_hyp.json', 'tc_orc_wer_ref.json'], True)
+
+        # DASR-Constrained-LM
+        test('DASR-Constrained-LM', valid_data, ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'], False, 'dev')
+        test('DASR-Constrained-LM', invalid_data, ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'], True, 'dev')
+
+        # DASR-Unconstrained-LM
+        test('DASR-Unconstrained-LM', valid_data, ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'], False, 'dev')
+        test('DASR-Unconstrained-LM', invalid_data, ['chime6.json', 'dipco.json', 'mixer6.json', 'notsofar1.json'], True, 'dev')
+
+
+if __name__ == '__main__':
+    test_validate_zip()
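
For context, a minimal usage sketch (not part of this commit) of the validate_zip entry point exercised by the tests above. The zip path is a hypothetical placeholder; the track name and the required entry fields are taken from the test calls and the fields list checked by validate_json_file_structure:

    from validation import validate_zip

    # Each JSON file in the zip is expected to hold a list of entries with the fields
    # 'session_id', 'words', 'speaker', 'start_time', 'end_time'.
    # 'my_submission.zip' is a placeholder path for illustration only.
    validate_zip('NOTSOFAR-SC', 'my_submission.zip')  # raises ValueError if the archive or JSON structure is invalid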