from transformers import HfArgumentParser
from dataclasses import dataclass, field
import logging

from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
from segment import (
    generate_segments,
    extract_segment,
    MIN_SAFETY_TOKENS,
    SAFETY_TOKENS_PERCENTAGE,
    word_start,
    word_end,
    SegmentationArguments
)
import preprocess
from errors import TranscriptError
from model import get_model_tokenizer_classifier, InferenceArguments

logging.basicConfig()
logger = logging.getLogger(__name__)


@dataclass
class PredictArguments(InferenceArguments):
    """Inference arguments extended with a single-video convenience option."""

    video_id: str = field(
        default=None,
        metadata={
            'help': 'Video to predict segments for'}
    )

    def __post_init__(self):
        # Fold the single `video_id` into the inherited `video_ids` list
        # before the parent class runs its own post-init processing.
        if self.video_id is not None:
            self.video_ids.append(self.video_id)

        super().__post_init__()


MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds

# Any prediction whose start time is <= this will be set to start at 0
START_TIME_ZERO_THRESHOLD = 0.08


def filter_and_add_probabilities(predictions, classifier, min_probability,
                                 overrule_probability=0.5):
    """Use classifier to filter predictions.

    Each prediction is re-classified from the text of its words. The
    predictions are updated in place with `probability` and `probabilities`
    keys, and possibly a new `category`.

    Args:
        predictions: list of prediction dicts, each with `words` (dicts
            holding a `text` key) and `category`.
        classifier: callable taking a list of texts and returning, for each
            text, an iterable of dicts with `label` and `score` keys.
        min_probability: predictions whose probability is below this value
            are dropped; `None` disables the threshold.
        overrule_probability: classifier score above which the classifier's
            (non-'none') category replaces the extracted category.

    Returns:
        The list of predictions that survived filtering.
    """
    if not predictions:
        return predictions

    # We update the predicted category from the extractive transformer
    # if the classifier is confident enough it is another category
    texts = [
        preprocess.clean_text(' '.join(word['text'] for word in pred['words']))
        for pred in predictions
    ]
    classifications = classifier(texts)

    filtered_predictions = []
    for prediction, probabilities in zip(predictions, classifications):
        predicted_probabilities = {
            p['label'].lower(): p['score'] for p in probabilities}

        # Get best category + probability
        classifier_category = max(
            predicted_probabilities, key=predicted_probabilities.get)
        classifier_probability = predicted_probabilities[classifier_category]

        unknown_category = prediction['category'] not in predicted_probabilities
        confident_overrule = (
            classifier_category != 'none'
            and classifier_probability > overrule_probability
        )
        if unknown_category or confident_overrule:
            # Unknown category or we are confident enough to overrule,
            # so change category to what was predicted by classifier
            prediction['category'] = classifier_category

        if prediction['category'] == 'none':
            continue  # Ignore if categorised as nothing

        prediction['probability'] = predicted_probabilities[prediction['category']]

        if min_probability is not None and prediction['probability'] < min_probability:
            continue  # Ignore if below threshold

        # TODO add probabilities, but remove None and normalise rest
        prediction['probabilities'] = predicted_probabilities

        filtered_predictions.append(prediction)

    return filtered_predictions


def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
    """Predict segments for a video.

    Args:
        video_id: video whose transcript is fetched when `words` is None.
        model: generation model handed to `segments_to_predictions`.
        tokenizer: tokenizer used for segmentation and generation.
        segmentation_args: options forwarded to `generate_segments`.
        words: optional pre-fetched transcript words.
        classifier: optional classifier used to filter the predictions.
        min_probability: threshold forwarded to the classifier filter.

    Returns:
        A list of prediction dicts with `start`, `end`, `category` and
        `words` (plus probability fields when a classifier is given).

    Raises:
        TranscriptError: if no words are available for the video.
    """
    # Allow words to be passed in so that we don't have to get the words if we already have them
    if words is None:
        words = preprocess.get_words(video_id)
        if not words:
            raise TranscriptError('Unable to retrieve transcript')

    segments = generate_segments(
        words,
        tokenizer,
        segmentation_args
    )

    predictions = segments_to_predictions(segments, model, tokenizer)

    # Add words back to time_ranges
    for prediction in predictions:
        # Stores words in the range
        prediction['words'] = extract_segment(
            words, prediction['start'], prediction['end'])

    if classifier is not None:
        predictions = filter_and_add_probabilities(
            predictions, classifier, min_probability)

    return predictions


def greedy_match(list, sublist):
    """Return index and length of longest matching sublist.

    Finds the longest contiguous run shared by `list` and `sublist`.

    NOTE: the parameter name `list` shadows the builtin; it is kept so that
    existing keyword callers continue to work. The builtin is not needed
    inside this function.

    Returns:
        A tuple `(i, j, k)`: start index in `list`, start index in
        `sublist`, and length of the run. `(-1, -1, 0)` when nothing matches.
    """
    best_i = -1
    best_j = -1
    best_k = 0

    list_len = len(list)
    sublist_len = len(sublist)

    for i in range(list_len):  # Start position in main list
        for j in range(sublist_len):  # Start position in sublist
            # A window cannot be wider than what remains in either sequence,
            # and only widths greater than the current best can improve it.
            widest = min(sublist_len - j, list_len - i)
            for k in range(widest, best_k, -1):  # Width of sublist window
                if list[i:i + k] == sublist[j:j + k]:
                    best_i, best_j, best_k = i, j, k
                    break  # Since window size decreases

    return best_i, best_j, best_k


def predict_sponsor_from_texts(texts, model, tokenizer):
    """Clean each text with `preprocess.clean_text`, then run extraction."""
    clean_texts = list(map(preprocess.clean_text, texts))
    return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)


def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
    """Given a body of text, predict the words which are part of the sponsor.

    Args:
        cleaned_texts: iterable whose items are either a sequence of cleaned
            words or an already-joined cleaned string.
        model: model exposing `parameters()`, `generate()` and `model_dim`.
        tokenizer: tokenizer used to encode the input and decode the output.

    Returns:
        A list with one decoded output string per input item.
    """
    model_device = next(model.parameters()).device

    decoded_outputs = []
    # Do individually, to avoid running out of memory for long videos
    for cleaned_words in cleaned_texts:
        # Joining a plain string would insert a space between every
        # character, so only join when given a sequence of words.
        if isinstance(cleaned_words, str):
            body = cleaned_words
        else:
            body = ' '.join(cleaned_words)
        text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + body

        input_ids = tokenizer(text, return_tensors='pt',
                              truncation=True).input_ids.to(model_device)

        # Optimise output length so that we do not generate unnecessarily long texts
        input_length = len(input_ids[0])
        max_out_len = round(min(
            max(
                input_length / SAFETY_TOKENS_PERCENTAGE,
                input_length + MIN_SAFETY_TOKENS
            ),
            model.model_dim)
        )

        outputs = model.generate(input_ids, max_length=max_out_len)
        decoded_outputs.append(tokenizer.decode(
            outputs[0], skip_special_tokens=True))

    return decoded_outputs


def segments_to_predictions(segments, model, tokenizer):
    """Turn transcript segments into merged, time-stamped predictions.

    The model output for every segment is matched back onto the segment's
    words to recover timestamps; the resulting ranges are then sorted and
    merged when they overlap or are close together.

    Args:
        segments: list of segments, each a list of word dicts that have a
            `cleaned` key (timing is read through `word_start`/`word_end`).
        model: generation model.
        tokenizer: tokenizer matching the model.

    Returns:
        A list of dicts with `start`, `end` and `category` keys.
    """
    predicted_time_ranges = []

    cleaned_texts = [
        [x['cleaned'] for x in cleaned_segment]
        for cleaned_segment in segments
    ]

    sponsorship_texts = predict_sponsor_from_cleaned_texts(
        cleaned_texts, model, tokenizer)

    matches = extract_sponsor_matches(sponsorship_texts)

    for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):

        for match in segment_matches:  # one segment might contain multiple sponsors/ir/selfpromos

            matched_text = match['text'].split()

            start_index, _, start_width = greedy_match(
                cleaned_batch, matched_text[:MATCH_WINDOW])
            end_index, _, end_width = greedy_match(
                cleaned_batch, matched_text[-MATCH_WINDOW:])

            # greedy_match signals "no match" with index -1 and width 0.
            # Using that sentinel as a slice bound would select unrelated
            # words (e.g. segment[i:-1]), so skip such matches entirely.
            if start_width == 0 or end_width == 0:
                continue

            extracted_words = segment[start_index:end_index + end_width]
            if not extracted_words:
                continue

            predicted_time_ranges.append({
                'start': word_start(extracted_words[0]),
                'end': word_end(extracted_words[-1]),
                'category': match['category']
            })

    # Necessary to sort matches by start time
    predicted_time_ranges.sort(key=word_start)

    # Merge overlapping predictions and sponsorships that are close together
    # Caused by model having max input size

    prev_end = None       # End of the prediction currently being extended
    prev_category = None  # Category of the most recently processed range
    final_predicted_time_ranges = []
    for time_range in predicted_time_ranges:
        start_time = time_range['start'] if time_range['start'] > START_TIME_ZERO_THRESHOLD else 0
        end_time = time_range['end']

        overlaps = prev_end is not None and start_time <= prev_end <= end_time
        # Merge disconnected segments if same category and within threshold
        close_same_category = (
            prev_end is not None
            and time_range['category'] == prev_category
            and start_time - prev_end <= MERGE_TIME_WITHIN
        )

        if overlaps or close_same_category:
            # Extend last prediction range. A range contained in the
            # previous prediction must never shrink it, hence max().
            last_prediction = final_predicted_time_ranges[-1]
            last_prediction['end'] = max(last_prediction['end'], end_time)
        else:  # No overlap, is a new prediction
            final_predicted_time_ranges.append({
                'start': start_time,
                'end': end_time,
                'category': time_range['category']
            })

        prev_end = final_predicted_time_ranges[-1]['end']
        prev_category = time_range['category']

    return final_predicted_time_ranges


def main():
    """Parse command-line arguments and print predictions for each video."""
    # Test on unseen data
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        PredictArguments,
        SegmentationArguments,
        GeneralArguments
    ))
    predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if not predict_args.video_ids:
        logger.error(
            'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        predict_args, general_args)

    for video_id in predict_args.video_ids:
        try:
            predictions = predict(video_id, model, tokenizer, segmentation_args,
                                  classifier=classifier,
                                  min_probability=predict_args.min_probability)
        except TranscriptError:
            logger.warning(f'No transcript available for {video_id}')
            continue

        video_url = f'https://www.youtube.com/watch?v={video_id}'
        if not predictions:
            logger.info(f'No predictions found for {video_url}')
            continue

        # TODO use predict_args.output_as_json
        print(len(predictions), 'predictions found for', video_url)
        for index, prediction in enumerate(predictions, start=1):
            print(f'Prediction #{index}:')
            print('Text: "',
                  ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
            print('Time:', seconds_to_time(
                prediction['start']), '\u2192', seconds_to_time(prediction['end']))
            print('Category:', prediction.get('category'))
            if 'probability' in prediction:
                print('Probability:', prediction['probability'])
            print()
        print()


if __name__ == '__main__':
    main()