File size: 598 Bytes
8b414b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from src.data_reader import load_train_test_df
from src.feature_extractors.bert_pretrain_extractor import \
BertPretrainFeatureExtractor
def test_pretrain_feature_extractor():
    """Check the shape of the features produced by BertPretrainFeatureExtractor.

    For every pretrained checkpoint name listed below, an extractor is built,
    run on the ``full_text`` column of the training frame loaded with
    ``is_testing=True``, and the result is required to have 5 rows and
    768 columns.
    """
    # NOTE(review): 5 is presumably the number of rows the loader returns in
    # testing mode and 768 the embedding width of both checkpoints -- confirm
    # against load_train_test_df and the extractor implementation.
    expected_rows = 5
    expected_columns = 768

    models = ['distilbert-base-uncased-finetuned-sst-2-english', 'bert-base-uncased']
    train_df, _ = load_train_test_df(is_testing=True)

    for model_name in models:
        feature_extractor = BertPretrainFeatureExtractor(model_name=model_name)
        output_features = feature_extractor.generate_features(train_df.full_text)

        # Two separate assertions (instead of one `A and B`) so a failure
        # reports which dimension is wrong and for which model.
        assert len(output_features) == expected_rows, (
            f"{model_name}: expected {expected_rows} rows, "
            f"got {len(output_features)}"
        )
        assert len(output_features.columns) == expected_columns, (
            f"{model_name}: expected {expected_columns} columns, "
            f"got {len(output_features.columns)}"
        )
|