File size: 598 Bytes
8b414b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from src.data_reader import load_train_test_df
from src.feature_extractors.bert_pretrain_extractor import \
    BertPretrainFeatureExtractor


def test_pretrain_feature_extractor():
    """Check the shape of features produced by BertPretrainFeatureExtractor.

    For each pretrained checkpoint name, build an extractor, run
    ``generate_features`` on the ``full_text`` column of the testing-mode
    training frame, and verify the row and column counts of the result.
    """
    # NOTE(review): 5 is presumably the number of rows that
    # load_train_test_df(is_testing=True) returns, and 768 the embedding
    # width of both checkpoints -- confirm against the loader and models.
    expected_rows = 5
    expected_columns = 768

    models = ['distilbert-base-uncased-finetuned-sst-2-english', 'bert-base-uncased']
    train_df, _ = load_train_test_df(is_testing=True)

    for model_name in models:
        feature_extractor = BertPretrainFeatureExtractor(model_name=model_name)

        output_features = feature_extractor.generate_features(train_df.full_text)

        # Two separate asserts, each naming the checkpoint, so a failure
        # pinpoints which model and which dimension is off instead of a
        # bare "assert False" from a compound condition.
        assert len(output_features) == expected_rows, (
            f'{model_name}: expected {expected_rows} rows, '
            f'got {len(output_features)}'
        )
        assert len(output_features.columns) == expected_columns, (
            f'{model_name}: expected {expected_columns} columns, '
            f'got {len(output_features.columns)}'
        )