unbias-one / load_model_pt.py
Jordan
Unbias - Version one push
10f417b
raw
history blame
617 Bytes
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
def load_pipeline(input_statement, pretrained_model_name):
    """Run text classification on ``input_statement`` and return the first result.

    A new ``text-classification`` pipeline is built on every call, using
    ``pretrained_model_name`` as the model and the PyTorch (``"pt"``)
    framework, so repeated calls pay the pipeline construction cost each time.

    Args:
        input_statement: The input handed directly to the pipeline.
            NOTE(review): presumably a single string; if a list is passed,
            only the result at index 0 is returned -- confirm with callers.
        pretrained_model_name: Model identifier or path forwarded to
            ``pipeline`` as ``model``.

    Returns:
        The element at index 0 of the pipeline's output.
        NOTE(review): for Hugging Face text classification this is typically
        a dict with ``label`` and ``score`` keys -- confirm against the
        installed transformers version.
    """
    text_classifier = pipeline(
        "text-classification",
        model=pretrained_model_name,
        framework="pt",
    )
    predictions = text_classifier(input_statement)
    # Only the first prediction is exposed to the caller.
    return predictions[0]
def load_models_from_pretrained(checkpoint):
    """Load the tokenizer and sequence-classification model for ``checkpoint``.

    Args:
        checkpoint: Identifier or path passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForSequenceClassification.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is loaded
    # before the model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )