import torch


def get_text(title: str, abstract: str):
    """Build the classification text from a paper's abstract and title.

    Returns ``abstract + ' ' + title`` when both are non-empty, the single
    non-empty field when only one is present, and ``None`` when both are
    empty/falsy.
    """
    if abstract and title:
        text = abstract + ' ' + title
    elif title:
        text = title
    elif abstract:
        # Bug fix: previously an abstract without a title was silently
        # dropped and None was returned.
        text = abstract
    else:
        text = None
    return text


def get_labels(text, model, tokenizer, count_labels=8):
    """Return the most probable arXiv category labels for *text*.

    Labels are accumulated in descending probability order until their
    cumulative probability would exceed 0.95; at least one label is always
    returned.

    Args:
        text: Input string to classify.
        model: HF-style sequence-classification model; called as
            ``model(**tokens)`` and read via ``outputs.logits``
            (assumed shape ``(1, 8)`` — one row, one logit per label).
        tokenizer: HF-style tokenizer; called as
            ``tokenizer(text, return_tensors='pt')``.
        count_labels: Kept for backward compatibility; the original code
            (incorrectly) passed it as the softmax dim. It is no longer
            used — the label set below is fixed at 8 entries.

    Returns:
        list[str]: Selected label names, highest probability first.
    """
    # Bug fix: was the undefined name `tokeinizer`.
    tokens = tokenizer(text, return_tensors='pt')
    outputs = model(**tokens)
    # Bug fix: softmax must run over the class axis (the last dim), not
    # dim=count_labels (=8), which is out of range for a (1, 8) tensor.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']
    sort_lst = sorted(zip(probs.detach().numpy()[0], labels),
                      key=lambda x: -x[0])
    cumsum = 0.0
    result_labels = []
    for prob, label in sort_lst:
        cumsum += prob
        # Stop before appending the label that pushes the cumulative
        # probability past 0.95 — but always keep at least one label.
        if cumsum > 0.95 and result_labels:
            break
        result_labels.append(label)
    # Bug fix: the original fell off the loop and implicitly returned None
    # when the threshold was never crossed.
    return result_labels