research / scripts /tests /model_test.py
haoqi7's picture
Upload 47 files
16188ba
raw
history blame
3.57 kB
def report_outputs(output, last_hidden_state, hidden_states, attentions, cross_attentions):
    """Print a shape/length summary of one model forward-pass output.

    Args:
        output: The raw object returned by the model; its type and its
            attribute names (``dir``) are printed for inspection.
        last_hidden_state: Object exposing ``.shape`` (a tensor in practice).
        hidden_states: Sized collection of hidden states; its length is printed.
        attentions: Sized collection of self-attention tensors; its length
            is printed.
        cross_attentions: Sized, non-empty collection of cross-attention
            tensors; its length and the shape of its first entry are printed.

    Returns:
        ``(last_hidden_state.shape, len(hidden_states), len(attentions),
        len(cross_attentions), cross_attentions[0].shape)`` -- the same values
        that were printed, so callers can assert on them.
    """
    print(type(output))
    print(dir(output))
    summary_values = (
        last_hidden_state.shape,
        len(hidden_states),
        len(attentions),
        len(cross_attentions),
        cross_attentions[0].shape,
    )
    for value in summary_values:
        print(value)
    return summary_values


if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the project's ``lrt`` package importable when this file is run
    # directly: scripts/tests/model_test.py -> three directory levels up.
    # Absolutise *before* walking up; otherwise a relative ``__file__``
    # (e.g. "model_test.py") collapses to "." and the wrong root is added.
    project_root = Path(__file__).absolute().parents[2]
    sys.path.append(str(project_root))

    import torch
    # Wildcard imports stay ahead of the explicit ones (as originally) so the
    # explicitly imported names take precedence if a module re-exports them.
    from lrt.clustering.models.keyBartPlus import *
    from lrt.clustering.models.adapter import *
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import os
    # Imported up front so a missing dependency fails before the slow tests.
    from torchinfo import summary

    ####################### Adapter Test #############################
    input_dim = 1024
    adapter_hid_dim = 256
    adapter = Adapter(input_dim, adapter_hid_dim)
    # Random dummy input of shape (10, 20, input_dim); reused by every test below.
    data = torch.randn(10, 20, input_dim)
    # These are shape smoke tests only; skip building autograd graphs.
    with torch.no_grad():
        adapted = adapter(data)
    # The adapter must preserve the shape of its input.
    assert data.size() == adapted.size()

    ####################### BartDecoderPlus Test #############################
    key_bart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
    bart_decoder_plus = BartDecoderPlus(key_bart, 100)
    with torch.no_grad():
        out = bart_decoder_plus(
            inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
            encoder_hidden_states=data,
        )
    report_outputs(
        out,
        out.last_hidden_state,
        out.hidden_states,
        out.attentions,
        out.cross_attentions,
    )

    ####################### BartPlus Test #############################
    bart_plus = BartPlus(key_bart, 100)
    with torch.no_grad():
        out = bart_plus(
            inputs_embeds=data,
            decoder_inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
        )
    report_outputs(
        out,
        out.last_hidden_state,
        out.decoder_hidden_states,
        out.decoder_attentions,
        out.cross_attentions,
    )

    ####################### Summary #############################
    summary(bart_plus)

    ####################### KeyBartAdapter Test #############################
    keybart_adapter = KeyBartAdapter(100)
    with torch.no_grad():
        out = keybart_adapter(
            inputs_embeds=data,
            decoder_inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
        )
    # NOTE(review): the encoder's last hidden state is reported here, not the
    # decoder's -- presumably because this output type exposes no
    # ``last_hidden_state``; confirm against KeyBartAdapter.forward.
    report_outputs(
        out,
        out.encoder_last_hidden_state,
        out.decoder_hidden_states,
        out.decoder_attentions,
        out.cross_attentions,
    )
    summary(keybart_adapter)