Spaces:
Runtime error
Runtime error
File size: 3,565 Bytes
16188ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
def print_model_outputs(outputs, *, last_hidden_key, hidden_states_key,
                        attentions_key):
    """Print a quick structural summary of a model's forward-pass output.

    The output's type and attribute names (``dir``) are printed *before* any
    attribute is read. If a requested attribute is missing, that listing is
    already on screen when the ``AttributeError`` surfaces.

    Args:
        outputs: object returned by a forward call. It must expose the three
            attributes named below plus a non-empty ``cross_attentions``.
        last_hidden_key: name of the attribute holding a tensor whose
            ``.shape`` is printed.
        hidden_states_key: name of the attribute holding a sequence whose
            length is printed.
        attentions_key: name of the attribute holding a sequence whose
            length is printed.

    Raises:
        AttributeError: if ``outputs`` lacks one of the requested attributes.
        IndexError: if ``outputs.cross_attentions`` is empty.
    """
    print(type(outputs))
    print(dir(outputs))
    last_hidden_state = getattr(outputs, last_hidden_key)
    hidden_states = getattr(outputs, hidden_states_key)
    attentions = getattr(outputs, attentions_key)
    cross_attentions = outputs.cross_attentions
    print(last_hidden_state.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attentions))
    print(cross_attentions[0].shape)


if __name__ == '__main__':
    import sys
    from pathlib import Path

    # Make the project root importable so the `lrt` imports below resolve
    # when this file is run directly. The root is three `.parent` hops up from
    # this file; on the original author's machine it was
    # /home/adapting/git/leoxiang66/idp_LiteratureResearch_Tool.
    project_root = Path(__file__).parent.parent.parent.absolute()
    sys.path.append(str(project_root))

    import torch
    from lrt.clustering.models.keyBartPlus import *
    from lrt.clustering.models.adapter import *
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import os

    # Every forward pass below is a smoke test that only inspects output
    # types and shapes. Running under torch.no_grad() avoids keeping
    # activations alive for a backward pass that never happens.

    ####################### Adapter Test #############################
    # NOTE(review): presumably this must equal KeyBART's hidden size, because
    # `data` is also fed to the KeyBART-based models below. TODO confirm.
    input_dim = 1024
    adapter_hid_dim = 256
    adapter = Adapter(input_dim, adapter_hid_dim)
    # Random tensor of shape (10, 20, input_dim). It is reused below as
    # `inputs_embeds`, `decoder_inputs_embeds` and `encoder_hidden_states`.
    data = torch.randn(10, 20, input_dim)
    with torch.no_grad():
        adapted = adapter(data)
    # The adapter must preserve the shape of its input.
    assert data.size() == adapted.size(), (
        f'Adapter changed shape: {tuple(data.size())} -> {tuple(adapted.size())}'
    )

    ####################### BartDecoderPlus Test #############################
    keyBart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
    # The second constructor argument (100) is forwarded unchanged. Its
    # meaning is defined in lrt.clustering.models.keyBartPlus. TODO document.
    bartDecoderP = BartDecoderPlus(keyBart, 100)
    with torch.no_grad():
        decoder_out = bartDecoderP(
            inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
            encoder_hidden_states=data,
        )
    print_model_outputs(
        decoder_out,
        last_hidden_key='last_hidden_state',
        hidden_states_key='hidden_states',
        attentions_key='attentions',
    )

    ####################### BartPlus Test #############################
    bartP = BartPlus(keyBart, 100)
    with torch.no_grad():
        bart_out = bartP(
            inputs_embeds=data,
            decoder_inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
        )
    print_model_outputs(
        bart_out,
        last_hidden_key='last_hidden_state',
        hidden_states_key='decoder_hidden_states',
        attentions_key='decoder_attentions',
    )

    ####################### Summary #############################
    # Imported here rather than at the top so that a missing torchinfo only
    # fails once the model tests above have already run.
    from torchinfo import summary
    summary(bartP)

    ####################### KeyBartAdapter Test #############################
    keybart_adapter = KeyBartAdapter(100)
    with torch.no_grad():
        keybart_out = keybart_adapter(
            inputs_embeds=data,
            decoder_inputs_embeds=data,
            output_attentions=True,
            output_hidden_states=True,
        )
    # Unlike the sections above, this reports the *encoder's* last hidden
    # state alongside the decoder's hidden states and attentions.
    print_model_outputs(
        keybart_out,
        last_hidden_key='encoder_last_hidden_state',
        hidden_states_key='decoder_hidden_states',
        attentions_key='decoder_attentions',
    )
    summary(keybart_adapter)