|
import torch |
|
import numpy as np |
|
# Per-(layer, head) vote of how strongly each attention head focuses on one
# contiguous span of keys, relative to all other keys.
#
# NOTE(review): the span meanings below are inferred from the original local
# names res_v / res_s / res_t and the sizes 35 / 576 -- confirm.
NUM_LAYERS = 32  # per-layer files "<layer>.pth" inside each shard directory
NUM_HEADS = 32  # width of the vote table; must equal dim 1 of the saved tensors
NUM_SHARDS = 4  # shard directories weights0/ .. weights3/
NUM_SAMPLES = 100  # samples scored per layer, after concatenating the shards
PREFIX_LEN = 35  # keys [0, 35): presumably the system/prefix tokens ("s")
VISUAL_LEN = 576  # keys [35, 611): presumably the visual tokens ("v")
RATIO_THRESHOLD = 0.38  # span mass / remaining mass needed for a vote
ALWAYS_KEEP_LAYERS = 2  # layers below this index vote whenever ratio >= 0
VOTE_THRESHOLD = 40  # a head is selected with strictly more votes than this


def load_layer_attention(layer_idx, num_shards=NUM_SHARDS):
    """Load ``weights{shard}/{layer_idx}.pth`` for every shard and concat on dim 0.

    ``map_location="cpu"`` lets files saved from a GPU load on any machine;
    the results are moved to the CPU for numpy anyway.
    """
    shards = [
        torch.load(f"weights{shard}/{layer_idx}.pth", map_location="cpu")
        for shard in range(num_shards)
    ]
    return torch.cat(shards, dim=0)


def head_vote_counts(
    attn_weights,
    layer_idx,
    num_samples=NUM_SAMPLES,
    prefix_len=PREFIX_LEN,
    visual_len=VISUAL_LEN,
    ratio_threshold=RATIO_THRESHOLD,
    always_keep_layers=ALWAYS_KEEP_LAYERS,
):
    """Return, per entry of dim 1, how many of the first samples cast a vote.

    Args:
        attn_weights: 4-D tensor indexed [sample, head, query, key]
            (assumed layout -- TODO confirm against the code that saved it).
        layer_idx: index of the layer the weights come from; layers below
            ``always_keep_layers`` vote for every sample with ratio >= 0.
        num_samples: how many leading samples to score.
        prefix_len, visual_len: the scored key span is
            ``[prefix_len, prefix_len + visual_len)``.
        ratio_threshold: minimum (span mass / remaining mass) to vote.

    Returns:
        numpy int array of length ``attn_weights.shape[1]`` with vote counts.

    Raises:
        ValueError: if fewer than ``num_samples`` samples are available.
    """
    available = attn_weights.shape[0]
    if available < num_samples:
        raise ValueError(f"expected at least {num_samples} samples, got {available}")

    attn = attn_weights[:num_samples]
    span_end = prefix_len + visual_len
    # Sum over (query, key) for all samples at once -> shape (samples, heads).
    res_v = attn[:, :, :, prefix_len:span_end].sum(dim=(2, 3))
    res_t = attn[:, :, :, span_end:].sum(dim=(2, 3))
    res_s = attn[:, :, :, :prefix_len].sum(dim=(2, 3))
    ratio = res_v / (res_t + res_s)

    if layer_idx >= always_keep_layers:
        mask = ratio > ratio_threshold
    else:
        # Early layers keep every head (only NaN ratios fail this test).
        mask = ratio >= 0

    return mask.int().sum(dim=0).detach().cpu().numpy()


def main():
    """Score every layer, print the selected heads, and return the tables."""
    mask_total = np.zeros((NUM_LAYERS, NUM_HEADS))
    for layer_idx in range(NUM_LAYERS):
        attn_weights = load_layer_attention(layer_idx)
        mask_total[layer_idx] += head_vote_counts(attn_weights, layer_idx)

    selected_heads = mask_total > VOTE_THRESHOLD
    for layer_idx, row in enumerate(selected_heads):
        print(f"layer {layer_idx}: heads {np.flatnonzero(row).tolist()}")
    return mask_total, selected_heads


if __name__ == "__main__":
    main()
|
|
|
# NOTE(review): disabled draft kept inside a discarded string literal; it is
# never executed. It sketches a variant that sums attention over dims
# [0, 2, 3] (presumably all samples at once, leaving one value per head)
# and thresholds the ratio at 0.12 instead of a per-sample vote.
# It would not run as written: `torch.load("weights"+)` is a syntax error and
# `attnweight1` / `attnweight2` are not defined anywhere in this file.
'''

attn_weights_lst=[]

for i in range(4):

attn_weights_lst

for j in range(40):

attn_weights_lst.append(torch.load("weights"+))

attn_weights = torch.cat([attnweight1,attnweight2],dim=0)

res_v = torch.sum(attn_weights[:,:,:,35:35+576],dim=[0,2,3])

res_t = torch.sum( attn_weights[:,:,:,35+576:],dim=[0,2,3])

res_s = torch.sum( attn_weights[:,:,:,:35],dim=[0,2,3])

res = res_v/(res_t+res_s)

mask = res>0.12

import pdb; pdb.set_trace()

'''
|
|