Algorithm 1: Attention with DropKey code
# N: token number, D: token dim
# Q: query (N, D), K: key (N, D), V: value (N, D)
# use_DropKey: whether to use DropKey
# mask_ratio: ratio of attention logits to mask

import torch

def Attention(Q, K, V, use_DropKey, mask_ratio):
    # scaled dot-product attention logits: (N, N)
    attn = (Q * (Q.shape[1] ** -0.5)) @ K.transpose(-2, -1)

    # use DropKey as a regularizer: randomly push a mask_ratio fraction
    # of logits to a large negative value so they vanish after softmax
    if use_DropKey:
        m_r = torch.ones_like(attn) * mask_ratio
        attn = attn + torch.bernoulli(m_r) * -1e12

    attn = attn.softmax(dim=-1)
    x = attn @ V
    return x
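As a quick sanity check, here is a minimal usage sketch; the token count, dimension, and mask_ratio value below are illustrative assumptions, not taken from the listing:

    N, D = 16, 64                  # assumed example sizes
    Q = torch.randn(N, D)
    K = torch.randn(N, D)
    V = torch.randn(N, D)

    out = Attention(Q, K, V, use_DropKey=True, mask_ratio=0.1)  # mask_ratio assumed
    print(out.shape)               # torch.Size([16, 64])

Because the mask is applied to the logits before the softmax, each query row still normalizes to a proper probability distribution over the surviving keys, which is what distinguishes dropping keys from dropping attention weights after the softmax.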