File size: 3,251 Bytes
59a9ccf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch

# Feature widths expected in the source checkpoint. With these values the
# template embedding weights are (d_templ, 88) and (d_templ, 52):
#   templ_emb.emb.weight      -> orig_t2d + 2 * orig_t1d = 88 columns
#   templ_emb.emb_t1d.weight  -> orig_t1d + orig_dtor    = 52 columns
orig_t1d  = 22
#orig_t1d  = 23
orig_t2d  = 44
orig_dtor = 30

# Target widths: t1d gains +2 (time step and seq confidence) + 4 (dssp)
# + 1 (hot spot); t2d is left unchanged.
new_t1d   = orig_t1d + 2 + 4 + 1
new_t2d   = orig_t2d + 0

# Source checkpoints used by earlier runs of this script:
#   /net/scratch/jgershon/models/autofold4_seq2str_base.pt
#   /home/jgershon/projects/c6d_diff/BFF/autofold4/experiments/2209235seqdiffV4_accum4_str5_aa5_continued/models/BFF_4.pt
SRC_CKPT = '/net/scratch/lisanza/diffuse_3track_fullcon/models/BFF_last.pt'

# Output names used by earlier runs of this script:
#   ./models/t1d_23_t2d_44_BFF_last.pt
#   ./models/t1d_24_t2d_44_BFF_last.pt
#   ./models/t1d_24_t2d_44_BFF_diffV4_accum4_str5_aa5_continued.pt
DST_CKPT = '/net/scratch/lisanza/projects/diffusion/models/t1d_29_t2d_44_BFF_SE3big_2.pt'


def pad_column_segments(weight, widths, pads):
    """Return a new tensor with zero columns appended after each column segment.

    ``weight`` (2D) is split along its last dimension into consecutive
    segments of the given ``widths``; segment ``i`` is followed by ``pads[i]``
    all-zero columns. The zero columns take their row count, dtype and device
    from ``weight`` itself, so nothing about the embedding size is hard-coded.

    Args:
        weight: 2D tensor of shape (rows, sum(widths)).
        widths: column count of each existing segment, in order.
        pads: number of zero columns to add after each segment.

    Returns:
        Tensor of shape (rows, sum(widths) + sum(pads)).

    Raises:
        ValueError: if ``widths`` and ``pads`` differ in length, a pad is
            negative, or ``widths`` does not add up to the actual number of
            columns (e.g. the checkpoint was already converted).
    """
    widths = list(widths)
    pads = list(pads)
    if len(widths) != len(pads):
        raise ValueError("widths and pads must have the same length")
    if any(pad < 0 for pad in pads):
        raise ValueError(f"cannot shrink a segment: pads={pads}")
    if sum(widths) != weight.shape[1]:
        # Fail loudly instead of silently writing misaligned weights.
        raise ValueError(
            f"expected {sum(widths)} columns {widths}, "
            f"found {weight.shape[1]}"
        )

    pieces = []
    start = 0
    for width, pad in zip(widths, pads):
        pieces.append(weight[:, start:start + width])
        if pad:
            pieces.append(weight.new_zeros((weight.shape[0], pad)))
        start += width
    return torch.cat(pieces, dim=-1)


def expand_template_embeddings(weights, old_t1d=orig_t1d, old_t2d=orig_t2d,
                               d_tor=orig_dtor, t1d_out=new_t1d,
                               t2d_out=new_t2d):
    """Widen the template embedding weights in ``weights`` for new features.

    The added input columns get zero weights, so the widened layers produce
    the same output as before for the original features.

    The 2D template embedding input is not laid out in the obvious way; the
    model builds it like this:

        # Prepare 2D template features
        left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
        right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)

        templ = torch.cat((t2d, left, right), -1)
        templ = self.emb(templ)

    so ``templ_emb.emb.weight`` has columns (t2d, t1d, t1d) and each of the
    two t1d segments must be padded separately. ``templ_emb.emb_t1d.weight``
    has columns (t1d, tor); only the t1d segment grows.

    Args:
        weights: model state dict; the two ``templ_emb`` entries are replaced
            in place.
        old_t1d, old_t2d, d_tor: feature widths of the incoming weights.
        t1d_out, t2d_out: feature widths after expansion.

    Returns:
        The same ``weights`` mapping, for convenience.

    Raises:
        ValueError: if the stored weights do not match the expected widths.
            ``weights`` is left untouched in that case.
    """
    emb_key = 'templ_emb.emb.weight'
    t1d_key = 'templ_emb.emb_t1d.weight'
    t1d_pad = t1d_out - old_t1d
    t2d_pad = t2d_out - old_t2d

    # Compute both tensors before assigning so a validation error cannot
    # leave the state dict half-converted.
    new_emb_weights = pad_column_segments(
        weights[emb_key],
        (old_t2d, old_t1d, old_t1d),
        (t2d_pad, t1d_pad, t1d_pad),
    )
    new_t1d_weights = pad_column_segments(
        weights[t1d_key],
        (old_t1d, d_tor),
        (t1d_pad, 0),
    )

    weights[emb_key] = new_emb_weights
    weights[t1d_key] = new_t1d_weights
    return weights


def main():
    """Load the source checkpoint, widen its template embeddings, save it."""
    # NOTE(review): torch.load unpickles arbitrary objects; only point this
    # at checkpoints from a trusted location.
    ckpt = torch.load(SRC_CKPT, map_location=torch.device('cpu'))

    weights = ckpt['model_state_dict']

    print("original weights")
    print('templ_emb.emb.weight', weights['templ_emb.emb.weight'].shape)
    print('templ_emb.emb_t1d.weight', weights['templ_emb.emb_t1d.weight'].shape)

    expand_template_embeddings(weights)
    print("new t1d weights dim")
    print(weights['templ_emb.emb_t1d.weight'].shape)

    ckpt['model_state_dict'] = weights
    torch.save(ckpt, DST_CKPT)


if __name__ == "__main__":
    main()