from src.data_reader import load_train_test_df | |
from src.feature_extractors.bert_pretrain_extractor import \ | |
BertPretrainFeatureExtractor | |
def test_pretrain_feature_extractor():
    """Check the shape of features generated by each pretrained model.

    For every model name listed below, a ``BertPretrainFeatureExtractor`` is
    built and ``generate_features`` is run on the ``full_text`` column of the
    training frame loaded with ``is_testing=True``. The result must have
    5 rows and 768 columns.
    """
    # NOTE(review): 5 is presumably the row count of the ``is_testing``
    # fixture and 768 the embedding width shared by both checkpoints;
    # confirm against the data reader and the model configs.
    expected_rows = 5
    expected_columns = 768

    models = [
        'distilbert-base-uncased-finetuned-sst-2-english',
        'bert-base-uncased',
    ]
    train_df, _ = load_train_test_df(is_testing=True)

    for model_name in models:
        feature_extractor = BertPretrainFeatureExtractor(model_name=model_name)
        output_features = feature_extractor.generate_features(train_df.full_text)

        # Assert each dimension separately so a failure names the offending
        # model and dimension instead of a bare compound-condition failure.
        assert len(output_features) == expected_rows, (
            f"{model_name}: expected {expected_rows} rows, "
            f"got {len(output_features)}"
        )
        assert len(output_features.columns) == expected_columns, (
            f"{model_name}: expected {expected_columns} columns, "
            f"got {len(output_features.columns)}"
        )