if __name__ == '__main__':
    import sys
    from pathlib import Path
    project_root = Path(__file__).parent.parent.parent.absolute()  # /home/adapting/git/leoxiang66/idp_LiteratureResearch_Tool
    sys.path.append(str(project_root))
    import torch
    from lrt.clustering.models.keyBartPlus import *
    from lrt.clustering.models.adapter import *
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import os

    ####################### Adapter Test #############################
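    # The Adapter should preserve the shape of its input: feed a random
    # 10 x 20 x 1024 batch through it and assert the output size is unchanged.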
    input_dim = 1024
    adapter_hid_dim = 256
    adapter = Adapter(input_dim, adapter_hid_dim)
    data = torch.randn(10, 20, input_dim)
    tmp = adapter(data)
    assert data.size() == tmp.size()
    ####################### Adapter Test #############################

    ####################### BartDecoderPlus Test #############################
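    # Load the pretrained KeyBART checkpoint and wrap its decoder with
    # BartDecoderPlus (the second argument is presumably the adapter hidden
    # size), then run a forward pass on embeddings and inspect the returned
    # hidden states, attentions, and cross-attentions.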
    keyBart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
    bartDecoderP = BartDecoderPlus(keyBart, 100)
    tmp = bartDecoderP(
        inputs_embeds=data,
        output_attentions=True,
        output_hidden_states=True,
        encoder_hidden_states=data,
    )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.last_hidden_state
    hidden_states = tmp.hidden_states
    attentions = tmp.attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)
    ####################### BartDecoderPlus Test #############################

    ####################### BartPlus Test #############################
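    # Same check for BartPlus, which wraps the full encoder-decoder: feed the
    # random embeddings to both encoder and decoder and inspect the
    # decoder-side hidden states, attentions, and cross-attentions.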
    bartP = BartPlus(keyBart, 100)
    tmp = bartP(
        inputs_embeds=data,
        decoder_inputs_embeds=data,
        output_attentions=True,
        output_hidden_states=True,
    )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.last_hidden_state
    hidden_states = tmp.decoder_hidden_states
    attentions = tmp.decoder_attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)
    ####################### BartPlus Test #############################

    ####################### Summary #############################
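    # torchinfo's summary prints the module hierarchy and parameter counts,
    # which gives a quick view of how many parameters the adapters add.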
    from torchinfo import summary
    summary(bartP)
    # summary(bartDecoderP)
    ####################### Summary #############################

    ####################### KeyBartAdapter Test #############################
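    # KeyBartAdapter builds the full adapter-augmented model itself (its single
    # argument is presumably the adapter hidden dimension), so the same forward
    # pass and output inspection are repeated on it.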
    keybart_adapter = KeyBartAdapter(100)
    tmp = keybart_adapter(
        inputs_embeds=data,
        decoder_inputs_embeds=data,
        output_attentions=True,
        output_hidden_states=True,
    )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.encoder_last_hidden_state
    hidden_states = tmp.decoder_hidden_states
    attentions = tmp.decoder_attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)
    summary(keybart_adapter)
    ####################### KeyBartAdapter Test #############################