# File-view page header captured along with this file:
# author mojtaba-nafez, commit 6a63e41 ("fix config.py"), size 6.28 kB.
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import BertTokenizer, BertModel, BertConfig, BertTokenizerFast
from transformers import XLMRobertaModel, XLMRobertaConfig
import os
"""
Configurations
"""
# Directory containing this file; used to build paths relative to it.
file_dirname = os.path.dirname(__file__)
# Dataset path for PoemTextModel training, validation and test.
# Passing the components separately keeps the path separator native on every OS.
dataset_path = os.path.join(file_dirname, "data", "Dataset-Merged.json")
# Path to append to the image filenames of the datasets used for CLIPModel training.
image_path = ""
random_seed = 3  # seed used to shuffle the dataset
# Fraction of the dataset used for each split. The "propotion" spelling is kept
# because other modules presumably read these settings under these exact names.
train_propotion = 0.85
val_propotion = 0.05
# The remaining 1 - train_propotion - val_propotion is used as the test set, so
# reject settings that would leave a negative (or out-of-range) share.
# The tiny tolerance absorbs float rounding when the two fractions sum to 1.
if not (
    0.0 <= train_propotion <= 1.0
    and 0.0 <= val_propotion <= 1.0
    and train_propotion + val_propotion <= 1.0 + 1e-9
):
    raise ValueError(
        "train_propotion ({}) and val_propotion ({}) must each lie in [0, 1] "
        "and sum to at most 1".format(train_propotion, val_propotion)
    )
batch_size = 128
num_workers = 0  # num_workers parameter of the torch DataLoader
lr = 1e-3  # learning rate
weight_decay = 1e-3
patience = 2  # patience parameter for the lr scheduler
factor = 0.5  # factor parameter for the lr scheduler
epochs = 60
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained Hugging Face checkpoints, selected by ``poem_encoder_model``.
poem_encoder_dict = {
    "Bert": {
        "poem_encoder_pretrained_name": 'mitra-mir/BERT-Persian-Poetry',
    },
    "ALBERT": {
        "poem_encoder_pretrained_name": 'mitra-mir/ALBERT-Persian-Poetry',
    },
    "ParsBERT": {
        "poem_encoder_pretrained_name": 'HooshvareLab/bert-base-parsbert-uncased',
    },
}
poem_encoder_model = "ParsBERT"  ### Important! The base model for poem encoder (one of "Bert", "ALBERT" and "ParsBERT")
# Keep this an empty string to use the pretrained weights from Hugging Face
# (poem_encoder_dict[poem_encoder_model]) or a fresh model; otherwise give the
# path to a saved encoder.
poem_encoder_load_path = ""
# Path to save the encoder to.
poem_encoder_save_path = "{}-poem-encoder".format(poem_encoder_model)
if poem_encoder_load_path:
    # A local checkpoint replaces the Hugging Face name for model and tokenizer.
    poem_encoder_pretrained_name = poem_encoder_load_path
else:
    if poem_encoder_model not in poem_encoder_dict:
        # Report the valid choices instead of a bare KeyError on the lookup below.
        raise KeyError(
            "poem_encoder_model must be one of {}, got {!r}".format(
                list(poem_encoder_dict), poem_encoder_model
            )
        )
    poem_encoder_pretrained_name = poem_encoder_dict[poem_encoder_model]['poem_encoder_pretrained_name']
# The tokenizer is always loaded from the same name/path as the encoder.
poem_tokenizer = poem_encoder_pretrained_name
poem_embedding = 768  # embedding dim of the poem encoder's output (per token)
poems_max_length = 64  # max_length used when padding/truncating poems with the poem tokenizer
# Keep this an empty string to use a freshly initialized projection module;
# otherwise give the path to a saved projection model.
poem_projection_load_path = os.path.join(file_dirname, "projections/{}_best_poem_projection.pt".format(poem_encoder_model))
# Path to save the projection to.
poem_projection_save_path = "{}_best_poem_projection.pt".format(poem_encoder_model)
poem_encoder_trainable = False  # if False, this encoder is frozen and its weights won't be saved at all.
# Pretrained Hugging Face checkpoints, selected by ``text_encoder_model``.
text_encoder_dict = {
    "M-Bert": {
        "text_encoder_pretrained_name": 'bert-base-multilingual-cased',
    },
    "XLM-RoBERTa": {
        "text_encoder_pretrained_name": 'xlm-roberta-base',
    },
    "LaBSE": {
        "text_encoder_pretrained_name": 'setu4993/LaBSE',
    },
}
text_encoder_model = 'LaBSE'  ### Important! The base model for text encoder (one of "M-Bert", "XLM-RoBERTa" and "LaBSE")
# Keep this an empty string to use the pretrained weights from Hugging Face
# (text_encoder_dict[text_encoder_model]) or a fresh model; otherwise give the
# path to a saved encoder.
text_encoder_load_path = ""
# Path to save the encoder to.
text_encoder_save_path = "{}-text-encoder".format(text_encoder_model)
if text_encoder_load_path:
    # A local checkpoint replaces the Hugging Face name for model and tokenizer.
    text_encoder_pretrained_name = text_encoder_load_path
else:
    if text_encoder_model not in text_encoder_dict:
        # Report the valid choices instead of a bare KeyError on the lookup below.
        raise KeyError(
            "text_encoder_model must be one of {}, got {!r}".format(
                list(text_encoder_dict), text_encoder_model
            )
        )
    text_encoder_pretrained_name = text_encoder_dict[text_encoder_model]["text_encoder_pretrained_name"]
# The tokenizer is always loaded from the same name/path as the encoder.
text_tokenizer = text_encoder_pretrained_name
text_embedding = 768  # embedding dim of the text encoder's output (per token)
text_max_length = 200  # max_length used when padding/truncating text with the text tokenizer
# Keep this an empty string to use a freshly initialized projection module;
# otherwise give the path to a saved projection model.
text_projection_load_path = os.path.join(file_dirname, "projections/{}_best_text_projection.pt".format(text_encoder_model))
# Path to save the projection to.
text_projection_save_path = "{}_best_text_projection.pt".format(text_encoder_model)
text_encoder_trainable = False  # if False, this encoder is frozen and its weights won't be saved at all.
# Image backbone name, loaded through the timm library.
image_encoder_model = 'resnet50'
# Per-token output dimension of the image encoder.
image_embedding = 2048
# When False the image encoder is frozen and its weights are never saved.
image_encoder_trainable = False

# Load paths: an empty string means "don't load anything". For the encoder,
# the default (pretrained or fresh) weights are used; for the projection, a
# freshly initialized module is used. Otherwise, give a path to load from.
image_encoder_weights_load_path = ""
image_projection_load_path = ""

# Filenames the best encoder weights and projection module are saved under.
image_encoder_weights_save_path = f"{image_encoder_model}_best_image_encoder.pt"
image_projection_save_path = f"{image_encoder_model}_best_image_projection.pt"
# Per encoder model name: the (tokenizer class, model class, config class)
# used to build the poem/text encoders. Listed once per model so that the
# three lookup tables below can't drift out of sync.
_encoder_classes = {
    "ALBERT": (AutoTokenizer, AutoModel, AutoConfig),
    "M-Bert": (BertTokenizer, BertModel, BertConfig),
    "XLM-RoBERTa": (AutoTokenizer, XLMRobertaModel, XLMRobertaConfig),
    "ParsBERT": (AutoTokenizer, AutoModel, AutoConfig),
    "Bert": (AutoTokenizer, AutoModel, AutoConfig),
    "LaBSE": (BertTokenizerFast, BertModel, BertConfig),
}
tokenizers = {name: classes[0] for name, classes in _encoder_classes.items()}
encoders = {name: classes[1] for name, classes in _encoder_classes.items()}
configs = {name: classes[2] for name, classes in _encoder_classes.items()}
# Drop the helper table so the module namespace matches the public settings only.
del _encoder_classes
# Temperature used to scale the dot-product similarities.
temperature = 1.0
# Image size (presumably the side length images are resized to — confirm).
size = 224
# Projection head shared by the poem, text and image encoders:
#   projection_dim - output embedding dimension of the projection
#   dropout        - fraction of the fc-layer output zeroed during training
projection_dim, dropout = 1024, 0.1