Joshua Lochner committed
Commit a294fb2
1 Parent(s): 31d605f

Improve output of evaluation script
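
The script parses its options with HfArgumentParser, so each dataclass field in the diff below surfaces as a command-line flag; the new start_index field becomes --start_index. A hypothetical invocation (values are illustrative only, and --model_path is assumed to be defined on TrainingOutputArguments):

    python src/evaluate.py --model_path ./out --start_index 50 --max_videos 100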

Files changed (1)
  1. src/evaluate.py +72 -28
src/evaluate.py CHANGED
@@ -1,3 +1,4 @@
+from model import get_model_tokenizer
 from utils import jaccard
 from datasets import load_dataset
 from transformers import (
@@ -5,10 +6,10 @@ from transformers import (
     AutoTokenizer,
     HfArgumentParser
 )
-from preprocess import DatasetArguments, ProcessedArguments, get_words
+from preprocess import DatasetArguments, get_words
 from shared import device, GeneralArguments
 from predict import ClassifierArguments, predict, TrainingOutputArguments
-from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
+from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
 import pandas as pd
 from dataclasses import dataclass, field
 from typing import Optional
@@ -16,6 +17,7 @@ from tqdm import tqdm
 import json
 import os
 import random
+from shared import seconds_to_time
 
 
 @dataclass
@@ -29,11 +31,8 @@ class EvaluationArguments(TrainingOutputArguments):
             'help': 'The number of videos to test on'
         }
     )
-
-    data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
-    dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
-        'validation_file']
-
+    start_index: int = field(default=None, metadata={
+        'help': 'Video to start the evaluation at.'})
     output_file: Optional[str] = field(
         default='metrics.csv',
         metadata={
@@ -98,13 +97,13 @@ def calculate_metrics(labelled_words, predictions):
 
         if predicted_sponsor:
             # total_positive_time += duration
-            if word['category'] is not None:  # Is actual sponsor
+            if word.get('category') is not None:  # Is actual sponsor
                 metrics['true_positive'] += duration
             else:
                 metrics['false_positive'] += duration
         else:
             # total_negative_time += duration
-            if word['category'] is not None:  # Is actual sponsor
+            if word.get('category') is not None:  # Is actual sponsor
                 metrics['false_negative'] += duration
             else:
                 metrics['true_negative'] += duration
@@ -141,34 +140,38 @@ def calculate_metrics(labelled_words, predictions):
 def main():
     hf_parser = HfArgumentParser((
         EvaluationArguments,
-        ProcessedArguments,
+        DatasetArguments,
         SegmentationArguments,
         ClassifierArguments,
        GeneralArguments
     ))
 
-    evaluation_args, processed_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
-
-    model = AutoModelForSeq2SeqLM.from_pretrained(evaluation_args.model_path)
-    model.to(device())
-
-    tokenizer = AutoTokenizer.from_pretrained(evaluation_args.model_path)
+    evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
 
-    dataset = load_dataset('json', data_files=os.path.join(
-        evaluation_args.data_dir, evaluation_args.dataset))['train']
+    model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
 
-    video_ids = [row['video_id'] for row in dataset]
-    random.shuffle(video_ids)  # TODO Make param
-
-    if evaluation_args.max_videos is not None:
-        video_ids = video_ids[:evaluation_args.max_videos]
+    # # TODO find better way of evaluating videos not trained on
+    # dataset = load_dataset('json', data_files=os.path.join(
+    #     dataset_args.data_dir, dataset_args.validation_file))['train']
+    # video_ids = [row['video_id'] for row in dataset]
 
     # Load labelled data:
     final_path = os.path.join(
-        processed_args.processed_dir, processed_args.processed_file)
+        dataset_args.data_dir, dataset_args.processed_file)
 
     with open(final_path) as fp:
         final_data = json.load(fp)
+    video_ids = list(final_data.keys())
+
+    random.shuffle(video_ids)
+
+    if evaluation_args.start_index is not None:
+        video_ids = video_ids[evaluation_args.start_index:]
+
+    if evaluation_args.max_videos is not None:
+        video_ids = video_ids[:evaluation_args.max_videos]
+
+    # TODO option to choose categories
 
     total_accuracy = 0
     total_precision = 0
@@ -179,9 +182,12 @@ def main():
 
     try:
         with tqdm(video_ids) as progress:
-            for video_id in progress:
+            for video_index, video_id in enumerate(progress):
+
                 progress.set_description(f'Processing {video_id}')
                 sponsor_segments = final_data.get(video_id, [])
+                if not sponsor_segments:
+                    continue  # Ignore empty
 
                 words = get_words(video_id)
                 if not words:
@@ -211,9 +217,47 @@ def main():
 
                 labelled_predicted_segments = attach_predictions_to_sponsor_segments(
                     predictions, sponsor_segments)
-                for seg in labelled_predicted_segments:
-                    if seg['best_prediction'] is None:
-                        print('\nNo match found for', seg)
+
+                # Identify possible issues:
+                missed_segments = [
+                    prediction for prediction in predictions if prediction['best_sponsorship'] is None]
+                incorrect_segments = [
+                    seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
+
+                if missed_segments or incorrect_segments:
+                    print('Issues identified for',
+                          video_id, f'(#{video_index})')
+                    # Potentially missed segments (model predicted, but not in database)
+                    if missed_segments:
+                        print(' - Missed segments:')
+                        for i, missed_segment in enumerate(missed_segments, start=1):
+                            print(f'\t#{i}:', seconds_to_time(
+                                missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
+                            print('\t\tText: "', ' '.join(
+                                [w['text'] for w in missed_segment['words']]), '"', sep='')
+                            print('\t\tCategory:',
+                                  missed_segment.get('category'))
+                            print('\t\tProbability:',
+                                  missed_segment.get('probability'))
+
+                    # Potentially incorrect segments (model didn't predict, but in database)
+                    if incorrect_segments:
+                        print(' - Incorrect segments:')
+                        for i, incorrect_segment in enumerate(incorrect_segments, start=1):
+                            print(f'\t#{i}:', seconds_to_time(
+                                incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
+
+                            seg_words = extract_segment(
+                                words, incorrect_segment['start'], incorrect_segment['end'])
+                            print('\t\tText: "', ' '.join(
+                                [w['text'] for w in seg_words]), '"', sep='')
+                            print('\t\tUUID:', incorrect_segment['uuid'])
+                            print('\t\tCategory:',
+                                  incorrect_segment['category'])
+                            print('\t\tVotes:', incorrect_segment['votes'])
+                            print('\t\tViews:', incorrect_segment['views'])
+                            print('\t\tLocked:', incorrect_segment['locked'])
+                            print()
 
     except KeyboardInterrupt:
         pass
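
For reference, the two helpers the new report relies on, seconds_to_time and extract_segment, are imported rather than defined in this file. A minimal sketch of what they might look like, assuming each word and segment carries 'start'/'end' keys in seconds (hypothetical; the real implementations live in src/shared.py and src/segment.py and may differ):

    def seconds_to_time(seconds):
        # Format a duration in seconds as H:MM:SS for the printed report.
        mins, secs = divmod(int(seconds), 60)
        hours, mins = divmod(mins, 60)
        return f'{hours:d}:{mins:02d}:{secs:02d}'

    def extract_segment(words, start, end):
        # Keep the transcript words whose timestamps fall within [start, end],
        # so the text of a database segment can be printed for inspection.
        return [w for w in words if w['start'] >= start and w['end'] <= end]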