Joshua Lochner commited on
Commit
537f2b7
1 Parent(s): 2782b0c

Add `--channel_id` parameter to evaluation script to run evaluation on a channel

Browse files
Files changed (1) hide show
  1. src/evaluate.py +110 -31
src/evaluate.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  from model import get_model_tokenizer
2
  from utils import jaccard
3
  from datasets import load_dataset
@@ -41,6 +45,13 @@ class EvaluationArguments(TrainingOutputArguments):
41
  }
42
  )
43
 
 
 
 
 
 
 
 
44
 
45
  def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
46
  """Attach sponsor segments to closest prediction"""
@@ -138,6 +149,56 @@ def calculate_metrics(labelled_words, predictions):
138
  return metrics
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def main():
142
  hf_parser = HfArgumentParser((
143
  EvaluationArguments,
@@ -162,15 +223,25 @@ def main():
162
 
163
  with open(final_path) as fp:
164
  final_data = json.load(fp)
165
- video_ids = list(final_data.keys())
166
 
167
- random.shuffle(video_ids)
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- if evaluation_args.start_index is not None:
170
- video_ids = video_ids[evaluation_args.start_index:]
171
 
172
- if evaluation_args.max_videos is not None:
173
- video_ids = video_ids[:evaluation_args.max_videos]
174
 
175
  # TODO option to choose categories
176
 
@@ -186,9 +257,11 @@ def main():
186
  for video_index, video_id in enumerate(progress):
187
 
188
  progress.set_description(f'Processing {video_id}')
189
- sponsor_segments = final_data.get(video_id, [])
 
190
  if not sponsor_segments:
191
- continue # Ignore empty
 
192
 
193
  words = get_words(video_id)
194
  if not words:
@@ -198,36 +271,42 @@ def main():
198
  predictions = predict(video_id, model, tokenizer,
199
  segmentation_args, words, classifier_args)
200
 
201
- labelled_words = add_labels_to_words(words, sponsor_segments)
202
- met = calculate_metrics(labelled_words, predictions)
203
- met['video_id'] = video_id
 
 
 
 
204
 
205
- out_metrics.append(met)
 
 
 
206
 
207
- total_accuracy += met['accuracy']
208
- total_precision += met['precision']
209
- total_recall += met['recall']
210
- total_fscore += met['f-score']
 
 
211
 
212
- progress.set_postfix({
213
- 'accuracy': total_accuracy/len(out_metrics),
214
- 'precision': total_precision/len(out_metrics),
215
- 'recall': total_recall/len(out_metrics),
216
- 'f-score': total_fscore/len(out_metrics)
217
- })
218
 
219
- labelled_predicted_segments = attach_predictions_to_sponsor_segments(
220
- predictions, sponsor_segments)
 
 
 
221
 
222
- # Identify possible issues:
223
- missed_segments = [
224
- prediction for prediction in predictions if prediction['best_sponsorship'] is None]
225
- incorrect_segments = [
226
- seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
227
 
228
  if missed_segments or incorrect_segments:
229
- print(
230
- f'Issues identified for https://youtu.be/{video_id} (#{video_index})')
231
  # Potentially missed segments (model predicted, but not in database)
232
  if missed_segments:
233
  print(' - Missed segments:')
 
1
+ import itertools
2
+ import base64
3
+ import re
4
+ import requests
5
  from model import get_model_tokenizer
6
  from utils import jaccard
7
  from datasets import load_dataset
 
45
  }
46
  )
47
 
48
+ channel_id: Optional[str] = field(
49
+ default=None,
50
+ metadata={
51
+ 'help': 'Used to evaluate a channel'
52
+ }
53
+ )
54
+
55
 
56
  def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
57
  """Attach sponsor segments to closest prediction"""
 
149
  return metrics
150
 
151
 
152
+ # Public innertube key (b64 encoded so that it is not incorrectly flagged)
153
+ INNERTUBE_KEY = base64.b64decode(
154
+ b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
155
+
156
+ YT_CONTEXT = {
157
+ 'client': {
158
+ 'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
159
+ 'clientName': 'WEB',
160
+ 'clientVersion': '2.20211221.00.00',
161
+ }
162
+ }
163
+ _YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
164
+
165
+
166
+ def get_all_channel_vids(channel_id):
167
+ continuation = None
168
+ while True:
169
+ if continuation is None:
170
+ params = {'list': channel_id.replace('UC', 'UU', 1)}
171
+ response = requests.get(
172
+ 'https://www.youtube.com/playlist', params=params)
173
+ items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
174
+ 'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
175
+ else:
176
+ params = {'key': INNERTUBE_KEY}
177
+ data = {
178
+ 'context': YT_CONTEXT,
179
+ 'continuation': continuation
180
+ }
181
+ response = requests.post(
182
+ 'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
183
+ items = response.json()[
184
+ 'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
185
+
186
+ new_token = None
187
+ for vid in items:
188
+ info = vid.get('playlistVideoRenderer')
189
+ if info:
190
+ yield info['videoId']
191
+ continue
192
+
193
+ info = vid.get('continuationItemRenderer')
194
+ if info:
195
+ new_token = info['continuationEndpoint']['continuationCommand']['token']
196
+
197
+ if new_token is None:
198
+ break
199
+ continuation = new_token
200
+
201
+
202
  def main():
203
  hf_parser = HfArgumentParser((
204
  EvaluationArguments,
 
223
 
224
  with open(final_path) as fp:
225
  final_data = json.load(fp)
 
226
 
227
+ if evaluation_args.channel_id is not None:
228
+ start = evaluation_args.start_index or 0
229
+ end = None if evaluation_args.max_videos is None else start + \
230
+ evaluation_args.max_videos
231
+
232
+ video_ids = list(itertools.islice(get_all_channel_vids(
233
+ evaluation_args.channel_id), start, end))
234
+ print('Found', len(video_ids), 'for channel', evaluation_args.channel_id)
235
+
236
+ else:
237
+ video_ids = list(final_data.keys())
238
+ random.shuffle(video_ids)
239
 
240
+ if evaluation_args.start_index is not None:
241
+ video_ids = video_ids[evaluation_args.start_index:]
242
 
243
+ if evaluation_args.max_videos is not None:
244
+ video_ids = video_ids[:evaluation_args.max_videos]
245
 
246
  # TODO option to choose categories
247
 
 
257
  for video_index, video_id in enumerate(progress):
258
 
259
  progress.set_description(f'Processing {video_id}')
260
+
261
+ sponsor_segments = final_data.get(video_id)
262
  if not sponsor_segments:
263
+ # TODO remove - parse using whole database
264
+ continue
265
 
266
  words = get_words(video_id)
267
  if not words:
 
271
  predictions = predict(video_id, model, tokenizer,
272
  segmentation_args, words, classifier_args)
273
 
274
+ if sponsor_segments:
275
+ labelled_words = add_labels_to_words(
276
+ words, sponsor_segments)
277
+ met = calculate_metrics(labelled_words, predictions)
278
+ met['video_id'] = video_id
279
+
280
+ out_metrics.append(met)
281
 
282
+ total_accuracy += met['accuracy']
283
+ total_precision += met['precision']
284
+ total_recall += met['recall']
285
+ total_fscore += met['f-score']
286
 
287
+ progress.set_postfix({
288
+ 'accuracy': total_accuracy/len(out_metrics),
289
+ 'precision': total_precision/len(out_metrics),
290
+ 'recall': total_recall/len(out_metrics),
291
+ 'f-score': total_fscore/len(out_metrics)
292
+ })
293
 
294
+ labelled_predicted_segments = attach_predictions_to_sponsor_segments(
295
+ predictions, sponsor_segments)
 
 
 
 
296
 
297
+ # Identify possible issues:
298
+ missed_segments = [
299
+ prediction for prediction in predictions if prediction['best_sponsorship'] is None]
300
+ incorrect_segments = [
301
+ seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]
302
 
303
+ else:
304
+ # Not in database (all segments missed)
305
+ missed_segments = predictions
306
+ incorrect_segments = None
 
307
 
308
  if missed_segments or incorrect_segments:
309
+ print(f'Issues identified for {video_id} (#{video_index})')
 
310
  # Potentially missed segments (model predicted, but not in database)
311
  if missed_segments:
312
  print(' - Missed segments:')