if __name__ == '__main__':
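    # Smoke tests for the adapter-augmented KeyBART modules (Adapter, BartDecoderPlus,
    # BartPlus, KeyBartAdapter) defined under lrt/clustering/models.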
    import sys
    from pathlib import Path

    # Make the project root importable, e.g. /home/adapting/git/leoxiang66/idp_LiteratureResearch_Tool
    project_root = Path(__file__).parent.parent.parent.absolute()
    sys.path.append(str(project_root))

    import torch
    from lrt.clustering.models.keyBartPlus import *
    from lrt.clustering.models.adapter import *
    from transformers import AutoModelForSeq2SeqLM

    ####################### Adapter Test #############################
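    # Assumed behaviour, inferred from the assert below: Adapter is a bottleneck
    # module (1024 -> 256 -> 1024 here) whose output keeps the input shape, so it
    # can be inserted residually between transformer layers. The dummy input is a
    # random (batch=10, seq_len=20, hidden=1024) tensor; 1024 matches BART-large's
    # d_model, which bloomberg/KeyBART is presumably built on.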
    input_dim = 1024
    adapter_hid_dim = 256
    adapter = Adapter(input_dim, adapter_hid_dim)

    data = torch.randn(10, 20, input_dim)

    tmp = adapter(data)

    assert data.size() == tmp.size()
    ####################### Adapter Test #############################

    ####################### BartDecoderPlus Test #############################
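    # BartDecoderPlus appears to wrap KeyBART's decoder with adapter layers; the
    # second argument (100) is presumably the adapter hidden size. The random
    # embeddings are fed both as decoder inputs and as stand-in encoder hidden
    # states, and the returned hidden-state/attention tuples are inspected below.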
    keyBart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
    bartDecoderP = BartDecoderPlus(keyBart, 100)
    tmp = bartDecoderP(inputs_embeds=data,
                       output_attentions=True,
                       output_hidden_states=True,
                       encoder_hidden_states=data
                       )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.last_hidden_state
    hidden_states = tmp.hidden_states
    attentions = tmp.attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)

    ####################### BartDecoderPlus Test #############################

    ####################### BartPlus Test #############################
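    # BartPlus appears to be the full encoder-decoder counterpart: the same
    # KeyBART weights with adapters, exercised here with random embeddings on
    # both the encoder and decoder side.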
    bartP = BartPlus(keyBart, 100)
    tmp = bartP(
        inputs_embeds=data,
        decoder_inputs_embeds=data,
        output_attentions=True,
        output_hidden_states=True,
    )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.last_hidden_state
    hidden_states = tmp.decoder_hidden_states
    attentions = tmp.decoder_attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)
    ####################### BartPlus Test #############################

    ####################### Summary #############################
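    # torchinfo.summary prints the module tree with parameter counts, a quick
    # way to check that the adapter layers were actually attached.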
    from torchinfo import summary

    summary(bartP)
    # summary(bartDecoderP)
    ####################### Summary #############################

    ####################### KeyBartAdapter Test #############################
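    # KeyBartAdapter(100) presumably loads bloomberg/KeyBART internally and adds
    # adapters of hidden size 100; the checks mirror the BartPlus test above,
    # except the seq2seq output exposes encoder_last_hidden_state instead of
    # last_hidden_state.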
    keybart_adapter = KeyBartAdapter(100)
    tmp = keybart_adapter(
        inputs_embeds=data,
        decoder_inputs_embeds=data,
        output_attentions=True,
        output_hidden_states=True,
    )
    print(type(tmp))
    # print(tmp.__dict__)
    print(dir(tmp))
    last_hid_states = tmp.encoder_last_hidden_state
    hidden_states = tmp.decoder_hidden_states
    attentions = tmp.decoder_attentions
    cross_attention = tmp.cross_attentions
    print(last_hid_states.shape)
    print(len(hidden_states))
    print(len(attentions))
    print(len(cross_attention))
    # print(cross_attention[0])
    print(cross_attention[0].shape)
    summary(keybart_adapter)
    ####################### KeyBartAdapter Test #############################