# Uploaded by zwt123home123 to Hugging Face via huggingface_hub
# ("Upload folder using huggingface_hub", commit 0580d34, verified).
import torch
import numpy as np
# Select attention heads that attend strongly to the image-token span.
#
# For each layer, attention dumps saved as weights{shard}/{layer}.pth are
# concatenated across shards. For the first NUM_SAMPLES samples, each head gets
# the ratio of attention mass on the image span to the mass on all other keys.
# A head is selected for a sample when its ratio exceeds RATIO_THRESHOLD (the
# first KEEP_ALL_LAYERS layers keep every head). Heads selected in more than
# COUNT_THRESHOLD samples are reported.
#
# NOTE(review): the 35 / 576 offsets look like a LLaVA-style prompt (prefix
# tokens, 24x24 image patches, then text). This is inferred from the slice
# offsets only; confirm against the model that produced the dumps.

SYS_LEN = 35  # keys before the image span
IMG_LEN = 576  # length of the image span
NUM_LAYERS = 32
NUM_SHARDS = 4  # directories weights0/ .. weights3/
NUM_SAMPLES = 100
# Historical settings from the original notes (label -> ratio threshold):
# headcut5: 0.35, headcut8: 0.38, headcut16: 0.11, headcut20: 0.06,
# headcut24: 0.044, headcut30: 0.03.
RATIO_THRESHOLD = 0.38
KEEP_ALL_LAYERS = 2  # layers [0, KEEP_ALL_LAYERS) keep all heads
# Minimum number of samples (out of NUM_SAMPLES) in which a head must be
# selected; the original note beside this cut-off read "headcut5 10 30".
COUNT_THRESHOLD = 40
# An earlier, unfinished variant pooled attention over all samples before
# taking the ratio and selected heads whose pooled ratio exceeded 0.12.


def load_layer_attn(layer_idx, num_shards=NUM_SHARDS):
    """Load one layer's saved attention shards and concatenate them along dim 0.

    Reads ``weights{shard}/{layer_idx}.pth`` for each shard in
    ``range(num_shards)``. Tensors are mapped to CPU so dumps saved from a GPU
    also load on CPU-only machines.
    """
    shards = [
        torch.load(f"weights{shard}/{layer_idx}.pth", map_location="cpu")
        for shard in range(num_shards)
    ]
    return torch.cat(shards, dim=0)


def visual_ratio(attn, num_samples=NUM_SAMPLES, sys_len=SYS_LEN, img_len=IMG_LEN):
    """Return the image-span attention ratio for the first ``num_samples`` samples.

    ``attn`` is indexed as a 4-D tensor; presumably (samples, heads, queries,
    keys), TODO confirm. Mass is summed over dims 2 and 3 within three key
    spans: ``[:sys_len]``, the image span ``[sys_len:sys_len + img_len]`` and
    the remainder. Returns shape ``(num_samples, attn.shape[1])`` holding
    image_mass / (prefix_mass + remainder_mass).

    Raises:
        ValueError: if ``attn`` holds fewer than ``num_samples`` samples.
    """
    if attn.shape[0] < num_samples:
        raise ValueError(
            f"expected at least {num_samples} samples, got {attn.shape[0]}"
        )
    window = attn[:num_samples]
    img_end = sys_len + img_len
    res_s = window[:, :, :, :sys_len].sum(dim=(2, 3))
    res_v = window[:, :, :, sys_len:img_end].sum(dim=(2, 3))
    res_t = window[:, :, :, img_end:].sum(dim=(2, 3))
    return res_v / (res_t + res_s)


def head_mask(ratio, layer_idx, threshold=RATIO_THRESHOLD, keep_all_layers=KEEP_ALL_LAYERS):
    """Return a bool tensor marking which entries of ``ratio`` select their head.

    Layers below ``keep_all_layers`` use ``ratio >= 0``, which keeps every head
    with a non-negative, non-NaN ratio. Later layers use ``ratio > threshold``.
    """
    if layer_idx < keep_all_layers:
        return ratio >= 0
    return ratio > threshold


def count_head_selections(
    load_layer=load_layer_attn,
    num_layers=NUM_LAYERS,
    num_samples=NUM_SAMPLES,
    threshold=RATIO_THRESHOLD,
    keep_all_layers=KEEP_ALL_LAYERS,
    sys_len=SYS_LEN,
    img_len=IMG_LEN,
):
    """Count, for every (layer, head), how many samples selected that head.

    ``load_layer(layer_idx)`` supplies each layer's attention tensor. Returns
    a NumPy array of shape ``(num_layers, attn.shape[1])`` with values in
    ``[0, num_samples]``. The head count comes from the data, not a constant.
    """
    rows = []
    for layer_idx in range(num_layers):
        ratio = visual_ratio(load_layer(layer_idx), num_samples, sys_len, img_len)
        mask = head_mask(ratio, layer_idx, threshold, keep_all_layers)
        # Summing over the sample axis replaces the per-sample `+=` loop.
        rows.append(mask.int().sum(dim=0).cpu().numpy())
    return np.stack(rows)


def main():
    """Compute per-head selection counts and print heads above COUNT_THRESHOLD."""
    mask_total = count_head_selections()
    selected = mask_total > COUNT_THRESHOLD
    for layer_idx, layer_sel in enumerate(selected):
        print(f"layer {layer_idx}: {np.flatnonzero(layer_sel).tolist()}")
    print(f"selected heads: {int(selected.sum())} / {selected.size}")


if __name__ == "__main__":
    main()