# Algorithm 1: Attention with DropKey (code)

# N: token number, D: token dim

# Q: query (N, D), K: key (N, D), V: value (N, D)

# use_DropKey: whether to apply DropKey

# mask_ratio: probability of masking each query-key attention logit

def Attention(Q, K, V, use_DropKey, mask_ratio):
    """Scaled dot-product attention with optional DropKey regularization.

    DropKey randomly masks attention logits *before* the softmax, so each
    query renormalizes its attention over the keys that survive, rather than
    dropping already-normalized attention weights afterwards.

    Args:
        Q: query tensor of shape (..., N, D).
        K: key tensor of shape (..., N_k, D).
        V: value tensor of shape (..., N_k, D_v).
        use_DropKey: if truthy, apply DropKey masking to the logits.
        mask_ratio: probability in [0, 1] of masking each query-key logit
            (``torch.bernoulli`` raises for values outside that range).

    Returns:
        Tensor of shape (..., N, D_v): the attention-weighted sum of V.
    """
    import torch

    # Scale by 1/sqrt(D). D is the last (feature) dimension; using
    # shape[-1] keeps this correct for batched / multi-head inputs such as
    # (B, H, N, D), where shape[1] would be the head count instead.
    attn = (Q * (Q.shape[-1] ** -0.5)) @ K.transpose(-2, -1)

    # use DropKey as regularizer
    if use_DropKey:
        # One Bernoulli(mask_ratio) draw per logit; True means "drop this key".
        mask = torch.bernoulli(torch.full_like(attn, mask_ratio)).bool()
        # Push masked logits to the most negative finite value of this dtype
        # so they receive ~zero weight after softmax. A tiny offset such as
        # -1e-12 would be a no-op; a finite minimum (instead of -1e12, which
        # overflows to -inf in fp16) keeps fully masked rows from becoming NaN.
        attn = attn.masked_fill(mask, torch.finfo(attn.dtype).min)

    attn = attn.softmax(dim=-1)
    return attn @ V