sponsorblock-ml / src /evaluate.py
Joshua Lochner
Fix incorrect segment output format
d7725ec
raw
history blame
16.9 kB
from model import get_model_tokenizer_classifier, InferenceArguments
from utils import jaccard, safe_print
from transformers import HfArgumentParser
from preprocess import get_words, clean_text
from shared import GeneralArguments, DatasetArguments
from predict import predict
from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
import pandas as pd
from dataclasses import dataclass, field
from typing import Optional
from tqdm import tqdm
import json
import os
import random
from shared import seconds_to_time
from urllib.parse import quote
import logging
# Attach a default handler to the root logger so records from this module's
# logger (raised to DEBUG inside main()) are actually emitted.
logging.basicConfig()
logger = logging.getLogger(__name__)
@dataclass
class EvaluationArguments(InferenceArguments):
    """Arguments pertaining to how evaluation will occur.

    Adds to the inference options a destination for the per-video metrics
    table and two switches that disable, respectively, the "missed segment"
    check (model predictions) and the "incorrect segment" check
    (re-classification of existing labelled segments).
    """

    # Where the per-video metrics table is written (CSV).
    output_file: Optional[str] = field(
        default='metrics.csv',
        metadata=dict(help='Save metrics to output file'),
    )

    # True -> no predictions are made, so missed segments are not searched for.
    skip_missing: bool = field(
        default=False,
        metadata=dict(
            help='Whether to skip checking for missing segments. If False, predictions will be made.'),
    )

    # True -> existing labelled segments are not re-classified.
    skip_incorrect: bool = field(
        default=False,
        metadata=dict(
            help='Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'),
    )
def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
    """Link every prediction to the labelled segment it overlaps the most.

    Each prediction dict is mutated in place and gains two keys:

    * ``'best_overlap'``: the highest ``jaccard`` score between the
      prediction's [start, end] and any labelled segment's [start, end];
      stays 0 when no segment scores above 0.
    * ``'best_sponsorship'``: the labelled segment achieving that score,
      or ``None`` when no segment scores above 0.

    Only a strictly higher score replaces the current best, so ties keep the
    earliest segment in ``sponsor_segments``.

    Returns:
        ``sponsor_segments``, unchanged.
    """
    for pred in predictions:
        top_score, top_segment = 0, None
        for candidate in sponsor_segments:
            score = jaccard(pred['start'], pred['end'],
                            candidate['start'], candidate['end'])
            if score > top_score:
                top_score, top_segment = score, candidate
        pred['best_overlap'] = top_score
        pred['best_sponsorship'] = top_segment
    return sponsor_segments
def _ratio_or_one(numerator, denominator):
    """Return ``numerator / denominator``, or 1 when the denominator is not positive."""
    return (numerator / denominator) if denominator > 0 else 1


def calculate_metrics(labelled_words, predictions):
    """Compute duration-weighted classification metrics for a single video.

    Every word except the final one adds its duration
    (``word_end(word) - word_start(word)``) to exactly one confusion-matrix
    cell, determined by:

    * predicted sponsor: ``word['start']`` lies within ``[p['start'], p['end']]``
      for some prediction ``p``;
    * actual sponsor: ``word.get('category')`` is not ``None``.

    Args:
        labelled_words: word dicts; labelled words carry a ``'category'``.
        predictions: prediction dicts with ``'start'`` and ``'end'`` keys.

    Returns:
        A dict with keys, in order: true_positive, false_negative,
        false_positive, true_negative, video_duration, accuracy, precision,
        recall, f-score. Accuracy/precision/recall whose denominator is 0 are
        reported as 1; f-score is 0 when precision + recall is 0.
        An empty ``labelled_words`` gives a video_duration of 0 rather than
        raising ``IndexError``.
    """
    metrics = {
        'true_positive': 0,  # Is sponsor, predicted sponsor
        # Is sponsor, predicted not sponsor (i.e., missed it - bad)
        'false_negative': 0,
        # Is not sponsor, predicted sponsor (classified incorrectly, not that bad since we do manual checking afterwards)
        'false_positive': 0,
        'true_negative': 0,  # Is not sponsor, predicted not sponsor
    }

    metrics['video_duration'] = (
        word_end(labelled_words[-1]) - word_start(labelled_words[0])
        if labelled_words else 0
    )

    # NOTE(review): the final word is excluded from the tally (as before);
    # the reason is not visible here.
    for word in labelled_words[:-1]:
        duration = word_end(word) - word_start(word)
        predicted_sponsor = any(
            p['start'] <= word['start'] <= p['end'] for p in predictions)
        is_sponsor = word.get('category') is not None

        if predicted_sponsor:
            cell = 'true_positive' if is_sponsor else 'false_positive'
        else:
            cell = 'false_negative' if is_sponsor else 'true_negative'
        metrics[cell] += duration

    tp = metrics['true_positive']
    tn = metrics['true_negative']
    fp = metrics['false_positive']
    fn = metrics['false_negative']

    # NOTE In cases where we encounter division by 0, we say that the value is 1
    # https://stats.stackexchange.com/a/1775
    # (Precision) TP+FP=0: means that all instances were predicted as negative
    # (Recall) TP+FN=0: means that there were no positive cases in the input data

    # The fraction of predictions our model got right
    metrics['accuracy'] = _ratio_or_one(tp + tn, tp + tn + fp + fn)

    # What proportion of positive identifications was actually correct?
    metrics['precision'] = _ratio_or_one(tp, tp + fp)

    # What proportion of actual positives was identified correctly?
    metrics['recall'] = _ratio_or_one(tp, tp + fn)

    # https://deepai.org/machine-learning-glossary-and-terms/f-score
    s = metrics['precision'] + metrics['recall']
    metrics['f-score'] = (2 * (metrics['precision'] *
                               metrics['recall']) / s) if s > 0 else 0

    return metrics
def _as_missed_segment(prediction):
    """Turn a model prediction into a missed-segment record (mutated and returned).

    The per-word list stored under ``'words'`` is removed and replaced by
    ``'text'``: the space-joined ``'text'`` of those words. Every reported
    missed segment therefore has the same shape, whether or not the video
    has labelled segments.
    """
    prediction_words = prediction.pop('words', [])
    prediction['text'] = ' '.join(x['text'] for x in prediction_words)
    return prediction


def _find_incorrect_segments(words, sponsor_segments, classifier):
    """Re-classify labelled segments and collect those the classifier disagrees with.

    Every segment in ``sponsor_segments`` gains a ``'text'`` key. Segments with
    fewer than 1.5 words per second (zero-length segments count as 0) and
    locked segments are not sent to the classifier. A checked segment whose
    top-scoring label (lower-cased) differs from its ``'category'`` gains
    ``'predicted'`` and ``'scores'`` keys and is reported as incorrect.

    Returns:
        ``(num_checked, num_correct, incorrect_segments)``. Assumes the
        classifier returns one score list per input text.
    """
    segments_to_check = []
    cleaned_texts = []  # Texts to send through tokenizer
    for sponsor_segment in sponsor_segments:
        segment_words = extract_segment(
            words, sponsor_segment['start'], sponsor_segment['end'])
        sponsor_segment['text'] = ' '.join(x['text'] for x in segment_words)

        duration = sponsor_segment['end'] - sponsor_segment['start']
        wps = (len(segment_words) / duration) if duration > 0 else 0
        if wps < 1.5:
            continue

        # Locked segments are not re-checked
        if sponsor_segment['locked']:
            continue

        cleaned_texts.append(clean_text(sponsor_segment['text']))
        segments_to_check.append(sponsor_segment)

    if not segments_to_check:
        return 0, 0, []

    num_correct = 0
    incorrect_segments = []
    segments_scores = classifier(cleaned_texts)
    for segment, scores in zip(segments_to_check, segments_scores):
        fixed_scores = {score['label']: score['score'] for score in scores}
        top = max(scores, key=lambda x: x['score'])
        predicted_category = top['label'].lower()
        if predicted_category == segment['category']:
            num_correct += 1
            continue  # Ignore correct segments

        segment.update({
            'predicted': predicted_category,
            'scores': fixed_scores,
        })
        incorrect_segments.append(segment)

    return len(segments_to_check), num_correct, incorrect_segments


def _report_issues(video_id, video_index, missed_segments, incorrect_segments, output_as_json):
    """Print the issues found for one video, either as one JSON line or as text."""
    if output_as_json:
        to_print = {'video_id': video_id}
        if missed_segments:
            to_print['missed'] = missed_segments
        if incorrect_segments:
            to_print['incorrect'] = incorrect_segments
        safe_print(json.dumps(to_print))
        return

    safe_print(f'Issues identified for {video_id} (#{video_index})')

    # Potentially missed segments (model predicted, but not in database)
    if missed_segments:
        safe_print(' - Missed segments:')
        segments_to_submit = []
        for i, missed_segment in enumerate(missed_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
            safe_print('\t\tText: "', missed_segment['text'], '"', sep='')
            category = missed_segment.get('category')
            safe_print('\t\tCategory:', category)
            if 'probability' in missed_segment:
                safe_print('\t\tProbability:', missed_segment['probability'])

            if category is None:
                continue  # A submission needs a category; skip rather than crash

            segments_to_submit.append({
                'segment': [missed_segment['start'], missed_segment['end']],
                'category': category.lower(),
                'actionType': 'skip'
            })

        if segments_to_submit:
            json_data = quote(json.dumps(segments_to_submit))
            safe_print(
                f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')

    # Incorrect segments (in database, but incorrectly classified)
    if incorrect_segments:
        safe_print(' - Incorrect segments:')
        for i, incorrect_segment in enumerate(incorrect_segments, start=1):
            safe_print(f'\t#{i}:', seconds_to_time(
                incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
            safe_print('\t\tText: "', incorrect_segment['text'], '"', sep='')
            safe_print('\t\tUUID:', incorrect_segment['uuid'])
            safe_print('\t\tVotes:', incorrect_segment['votes'])
            safe_print('\t\tViews:', incorrect_segment['views'])
            safe_print('\t\tLocked:', incorrect_segment['locked'])
            safe_print('\t\tCurrent Category:', incorrect_segment['category'])
            safe_print('\t\tPredicted Category:', incorrect_segment['predicted'])
            safe_print('\t\tProbabilities:')
            for label, score in incorrect_segment['scores'].items():
                safe_print(f"\t\t\t{label}: {score}")

    safe_print()


def main():
    """Evaluate the model against the processed database.

    For each selected video:

    * unless ``skip_missing``: makes predictions, scores them against the
      labelled segments (``calculate_metrics``), and reports predictions that
      overlap no labelled segment as "missed";
    * unless ``skip_incorrect``: re-classifies labelled segments and reports
      those whose predicted category differs from the stored one.

    Per-video metrics are written to ``output_file`` (even after Ctrl+C),
    and their column means are logged.
    """
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        EvaluationArguments,
        DatasetArguments,
        SegmentationArguments,
        GeneralArguments
    ))
    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
        logger.error('ERROR: Nothing to do')
        return

    # Load labelled data:
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)

    if not os.path.exists(final_path):
        logger.error('ERROR: Processed database not found.\n'
                     f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        evaluation_args, general_args)

    with open(final_path) as fp:
        final_data = json.load(fp)

    if evaluation_args.video_ids:  # Use specified (order preserved)
        video_ids = evaluation_args.video_ids
    else:  # Use items found in preprocessed database
        video_ids = list(final_data.keys())
        random.shuffle(video_ids)

    if evaluation_args.start_index is not None:
        video_ids = video_ids[evaluation_args.start_index:]

    if evaluation_args.max_videos is not None:
        video_ids = video_ids[:evaluation_args.max_videos]

    out_metrics = []

    all_metrics = {}
    if not evaluation_args.skip_missing:
        all_metrics['total_prediction_accuracy'] = 0
        all_metrics['total_prediction_precision'] = 0
        all_metrics['total_prediction_recall'] = 0
        all_metrics['total_prediction_fscore'] = 0

    if not evaluation_args.skip_incorrect:
        all_metrics['classifier_segment_correct'] = 0
        all_metrics['classifier_segment_count'] = 0

    # Number of videos whose prediction metrics were actually computed; used
    # as the denominator for the displayed averages.
    metric_count = 0

    postfix_info = {}

    try:
        with tqdm(video_ids) as progress:
            for video_index, video_id in enumerate(progress):
                progress.set_description(f'Processing {video_id}')

                words = get_words(video_id)
                if not words:
                    continue

                # Get labels
                sponsor_segments = final_data.get(video_id)

                # Reset previous
                missed_segments = []
                incorrect_segments = []

                current_metrics = {
                    'video_id': video_id
                }

                if not evaluation_args.skip_missing:  # Make predictions
                    predictions = predict(video_id, model, tokenizer, segmentation_args,
                                          classifier=classifier,
                                          min_probability=evaluation_args.min_probability)

                    if sponsor_segments:
                        labelled_words = add_labels_to_words(
                            words, sponsor_segments)

                        current_metrics.update(
                            calculate_metrics(labelled_words, predictions))
                        metric_count += 1

                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
                        all_metrics['total_prediction_precision'] += current_metrics['precision']
                        all_metrics['total_prediction_recall'] += current_metrics['recall']
                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']

                        # Just for display purposes
                        postfix_info.update({
                            'accuracy': all_metrics['total_prediction_accuracy'] / metric_count,
                            'precision': all_metrics['total_prediction_precision'] / metric_count,
                            'recall': all_metrics['total_prediction_recall'] / metric_count,
                            'f-score': all_metrics['total_prediction_fscore'] / metric_count,
                        })

                        sponsor_segments = attach_predictions_to_sponsor_segments(
                            predictions, sponsor_segments)

                        # Identify possible issues: predictions overlapping no labelled segment
                        missed_segments = [
                            _as_missed_segment(prediction)
                            for prediction in predictions
                            if prediction['best_sponsorship'] is None
                        ]
                    else:
                        # Not in database (all segments missed)
                        missed_segments = [
                            _as_missed_segment(prediction) for prediction in predictions]

                if not evaluation_args.skip_incorrect and sponsor_segments:
                    # Check for incorrect segments using the classifier
                    num_checked, num_correct, incorrect_segments = _find_incorrect_segments(
                        words, sponsor_segments, classifier)

                    if num_checked:
                        all_metrics['classifier_segment_count'] += num_checked
                        current_metrics['num_segments'] = num_checked
                        current_metrics['classified_correct'] = num_correct
                        all_metrics['classifier_segment_correct'] += num_correct

                    if all_metrics['classifier_segment_count'] > 0:
                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
                            all_metrics['classifier_segment_count']

                out_metrics.append(current_metrics)
                progress.set_postfix(postfix_info)

                if missed_segments or incorrect_segments:
                    _report_issues(video_id, video_index, missed_segments,
                                   incorrect_segments, evaluation_args.output_as_json)

    except KeyboardInterrupt:
        pass  # Stop early, but still save the metrics gathered so far

    df = pd.DataFrame(out_metrics)

    if evaluation_args.output_file:
        df.to_csv(evaluation_args.output_file)

    # numeric_only: the 'video_id' column is text, which pandas >= 2 refuses to average
    logger.info(df.mean(numeric_only=True))


if __name__ == '__main__':
    main()