from functools import lru_cache

from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
# Small bound: every cached entry keeps a full set of model weights alive.
@lru_cache(maxsize=4)
def get_lm(lm_name):
    """Load a causal language model, reusing it on repeated calls.

    Args:
        lm_name: Model identifier or local path, passed straight to
            ``AutoModelForCausalLM.from_pretrained``. It must be hashable
            because results are memoized on this argument.

    Returns:
        The model produced by ``AutoModelForCausalLM.from_pretrained``.
        Repeated calls with the same ``lm_name`` return the *same* object,
        so in-place changes made by one caller (e.g. fine-tuning or moving
        it to another device) are visible to every other caller.
    """
    # Loading weights is expensive; memoize so each model is read once per
    # process (``lru_cache`` was imported for this purpose but never applied).
    return AutoModelForCausalLM.from_pretrained(lm_name)
# Small bound: every cached entry keeps a full set of model weights alive.
@lru_cache(maxsize=4)
def get_cls(cls_name):
    """Load a sequence-classification model, reusing it on repeated calls.

    Args:
        cls_name: Model identifier or local path, passed straight to
            ``AutoModelForSequenceClassification.from_pretrained``. It must
            be hashable because results are memoized on this argument.

    Returns:
        The model produced by
        ``AutoModelForSequenceClassification.from_pretrained``. Repeated
        calls with the same ``cls_name`` return the *same* object, so
        in-place changes made by one caller (e.g. fine-tuning or moving it
        to another device) are visible to every other caller.
    """
    # Loading weights is expensive; memoize so each model is read once per
    # process (``lru_cache`` was imported for this purpose but never applied).
    return AutoModelForSequenceClassification.from_pretrained(cls_name)