linguask / tests /test_feature_extractor.py
GitHub Action
refs/heads/ci-cd/hugging-face
8b414b0
raw
history blame contribute delete
598 Bytes
from src.data_reader import load_train_test_df
from src.feature_extractors.bert_pretrain_extractor import \
BertPretrainFeatureExtractor
def test_pretrain_feature_extractor():
    """Check the feature-table dimensions produced for each pretrained model.

    Loads the training frame via ``load_train_test_df(is_testing=True)``,
    builds a ``BertPretrainFeatureExtractor`` for every checkpoint name
    below, and asserts that ``generate_features`` on the ``full_text``
    column returns an object with 5 rows and 768 columns.
    """
    checkpoint_names = (
        'distilbert-base-uncased-finetuned-sst-2-english',
        'bert-base-uncased',
    )
    # Only the training split is needed; the second returned value is ignored.
    train_df, _ = load_train_test_df(is_testing=True)

    for model_name in checkpoint_names:
        extractor = BertPretrainFeatureExtractor(model_name=model_name)
        features = extractor.generate_features(train_df.full_text)

        # NOTE(review): 5 is presumably the row count of the is_testing
        # fixture -- confirm against load_train_test_df.
        assert len(features) == 5
        # NOTE(review): 768 presumably matches the hidden size of both
        # checkpoints -- confirm against the model configs.
        assert len(features.columns) == 768