from transformers import HfArgumentParser
from dataclasses import dataclass, field
import logging
from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
from segment import (
    generate_segments,
    extract_segment,
    MIN_SAFETY_TOKENS,
    SAFETY_TOKENS_PERCENTAGE,
    word_start,
    word_end,
    SegmentationArguments
)
import preprocess
from errors import TranscriptError
from model import get_model_tokenizer_classifier, InferenceArguments

logging.basicConfig()
logger = logging.getLogger(__name__)


@dataclass
class PredictArguments(InferenceArguments):
    video_id: str = field(
        default=None,
        metadata={
            'help': 'Video to predict segments for'}
    )

    def __post_init__(self):
        if self.video_id is not None:
            self.video_ids.append(self.video_id)

        super().__post_init__()


MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds

# Any prediction whose start time is <= this will be set to start at 0
START_TIME_ZERO_THRESHOLD = 0.08


def filter_and_add_probabilities(predictions, classifier, min_probability):
    """Use the classifier to re-score predictions, attach probabilities and
    filter out low-confidence or 'none' results."""
    if not predictions:
        return predictions

    # We update the predicted category from the extractive transformer
    # if the classifier is confident enough it is another category
    texts = [
        preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
        for pred in predictions
    ]
    classifications = classifier(texts)
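    # `classifications` is assumed to follow the Hugging Face text-classification
    # pipeline format with all scores returned: one list of {'label': ..., 'score': ...}
    # dicts per input text, e.g. (labels and scores illustrative):
    #   [[{'label': 'SPONSOR', 'score': 0.93}, {'label': 'NONE', 'score': 0.07}], ...]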
    filtered_predictions = []
    for prediction, probabilities in zip(predictions, classifications):
        predicted_probabilities = {
            p['label'].lower(): p['score'] for p in probabilities}

        # Get best category + probability
        classifier_category = max(
            predicted_probabilities, key=predicted_probabilities.get)
        classifier_probability = predicted_probabilities[classifier_category]

        if (prediction['category'] not in predicted_probabilities) \
                or (classifier_category != 'none' and classifier_probability > 0.5):  # TODO make param
            # Unknown category or we are confident enough to overrule,
            # so change category to what was predicted by classifier
            prediction['category'] = classifier_category

        if prediction['category'] == 'none':
            continue  # Ignore if categorised as nothing

        prediction['probability'] = predicted_probabilities[prediction['category']]

        if min_probability is not None and prediction['probability'] < min_probability:
            continue  # Ignore if below threshold

        # TODO add probabilities, but remove None and normalise rest
        prediction['probabilities'] = predicted_probabilities

        # if prediction['probability'] < classifier_args.min_probability:
        #     continue

        filtered_predictions.append(prediction)

    return filtered_predictions


def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
    # Allow words to be passed in so that we don't have to get the words if we already have them
    if words is None:
        words = preprocess.get_words(video_id)
    if not words:
        raise TranscriptError('Unable to retrieve transcript')

    segments = generate_segments(
        words,
        tokenizer,
        segmentation_args
    )

    predictions = segments_to_predictions(segments, model, tokenizer)

    # Add words back to time_ranges
    for prediction in predictions:
        # Stores words in the range
        prediction['words'] = extract_segment(
            words, prediction['start'], prediction['end'])

    if classifier is not None:
        predictions = filter_and_add_probabilities(
            predictions, classifier, min_probability)

    return predictions
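
# Example usage (the video ID is illustrative; `model`, `tokenizer` and
# `segmentation_args` are obtained as in main() below):
#   predictions = predict('dQw4w9WgXcQ', model, tokenizer, segmentation_args)
#   -> [{'start': 12.3, 'end': 45.6, 'category': 'sponsor', 'words': [...]}, ...]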


def greedy_match(lst, sublist):
    # Return (index in lst, index in sublist, length) of the longest matching window
    best_i = -1
    best_j = -1
    best_k = 0
    for i in range(len(lst)):  # Start position in main list
        for j in range(len(sublist)):  # Start position in sublist
            for k in range(len(sublist)-j, 0, -1):  # Width of sublist window
                if k > best_k and lst[i:i+k] == sublist[j:j+k]:
                    best_i, best_j, best_k = i, j, k
                    break  # Since window size decreases
    return best_i, best_j, best_k
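
# Illustrative example of greedy_match (string tokens shown for clarity):
#   greedy_match(['hey', 'this', 'video', 'is', 'sponsored'], ['video', 'is'])
#   returns (2, 0, 2): the match starts at index 2 of the main list and
#   index 0 of the sublist, and is 2 items long.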


def predict_sponsor_from_texts(texts, model, tokenizer):
    clean_texts = list(map(preprocess.clean_text, texts))
    return predict_sponsor_from_cleaned_texts(clean_texts, model, tokenizer)


def predict_sponsor_from_cleaned_texts(cleaned_texts, model, tokenizer):
    """Given cleaned transcript texts, predict the words which are part of the sponsor"""
    model_device = next(model.parameters()).device

    decoded_outputs = []
    # Do individually, to avoid running out of memory for long videos
    for cleaned_words in cleaned_texts:
        text = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value + \
            ' '.join(cleaned_words)
        input_ids = tokenizer(text, return_tensors='pt',
                              truncation=True).input_ids.to(model_device)

        # Optimise output length so that we do not generate unnecessarily long texts
        max_out_len = round(min(
            max(
                len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
                len(input_ids[0]) + MIN_SAFETY_TOKENS
            ),
            model.model_dim)
        )
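        # Illustrative arithmetic (constants' values are hypothetical): with
        # SAFETY_TOKENS_PERCENTAGE = 0.5 and MIN_SAFETY_TOKENS = 8, a 100-token
        # input gives max(100/0.5, 100 + 8) = 200, capped at model.model_dim.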
        outputs = model.generate(input_ids, max_length=max_out_len)
        decoded_outputs.append(tokenizer.decode(
            outputs[0], skip_special_tokens=True))

    return decoded_outputs


def segments_to_predictions(segments, model, tokenizer):
    predicted_time_ranges = []

    cleaned_texts = [
        [x['cleaned'] for x in cleaned_segment]
        for cleaned_segment in segments
    ]

    sponsorship_texts = predict_sponsor_from_cleaned_texts(
        cleaned_texts, model, tokenizer)

    matches = extract_sponsor_matches(sponsorship_texts)

    for segment_matches, cleaned_batch, segment in zip(matches, cleaned_texts, segments):
        for match in segment_matches:  # one segment might contain multiple sponsors/ir/selfpromos
            matched_text = match['text'].split()

            i1, j1, k1 = greedy_match(
                cleaned_batch, matched_text[:MATCH_WINDOW])
            i2, j2, k2 = greedy_match(
                cleaned_batch, matched_text[-MATCH_WINDOW:])

            extracted_words = segment[i1:i2+k2]
            if not extracted_words:
                continue

            predicted_time_ranges.append({
                'start': word_start(extracted_words[0]),
                'end': word_end(extracted_words[-1]),
                'category': match['category']
            })

    # Necessary to sort matches by start time
    predicted_time_ranges.sort(key=word_start)

    # Merge overlapping predictions and sponsorships that are close together
    # Caused by model having max input size
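    # Illustrative merge (hypothetical times): a 'sponsor' range ending at 20s and
    # another starting at 25s are 5s apart (<= MERGE_TIME_WITHIN = 8), so they are
    # merged into one range; non-overlapping ranges further apart stay separate.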
    prev_prediction = None

    final_predicted_time_ranges = []
    for time_range in predicted_time_ranges:
        start_time = time_range['start'] if time_range['start'] > START_TIME_ZERO_THRESHOLD else 0
        end_time = time_range['end']

        if prev_prediction is not None and \
                (start_time <= prev_prediction['end'] <= end_time or  # Merge overlapping segments
                 (time_range['category'] == prev_prediction['category']  # Merge disconnected segments if same category and within threshold
                  and start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN)):
            # Extend last prediction range
            final_predicted_time_ranges[-1]['end'] = end_time

        else:  # No overlap, is a new prediction
            final_predicted_time_ranges.append({
                'start': start_time,
                'end': end_time,
                'category': time_range['category']
            })

        prev_prediction = time_range

    return final_predicted_time_ranges
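
# Example invocation (the video ID is illustrative; model, segmentation and
# classifier options come from InferenceArguments, SegmentationArguments and
# GeneralArguments):
#   python predict.py --video_id dQw4w9WgXcQ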


def main():
    # Test on unseen data
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        PredictArguments,
        SegmentationArguments,
        GeneralArguments
    ))
    predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if not predict_args.video_ids:
        logger.error(
            'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        predict_args, general_args)

    for video_id in predict_args.video_ids:
        try:
            predictions = predict(video_id, model, tokenizer, segmentation_args,
                                  classifier=classifier,
                                  min_probability=predict_args.min_probability)
        except TranscriptError:
            logger.warning(f'No transcript available for {video_id}')
            continue

        video_url = f'https://www.youtube.com/watch?v={video_id}'
        if not predictions:
            logger.info(f'No predictions found for {video_url}')
            continue

        # TODO use predict_args.output_as_json
        print(len(predictions), 'predictions found for', video_url)
        for index, prediction in enumerate(predictions, start=1):
            print(f'Prediction #{index}:')
            print('Text: "',
                  ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
            print('Time:', seconds_to_time(
                prediction['start']), '\u2192', seconds_to_time(prediction['end']))
            print('Category:', prediction.get('category'))
            if 'probability' in prediction:
                print('Probability:', prediction['probability'])
            print()
        print()


if __name__ == '__main__':
    main()