26、UFO模塊
論文《UFO-ViT: High Performance Linear Vision Transformer without Softmax》
1、作用
UFO-ViT旨在解決傳統(tǒng)Transformer在視覺任務(wù)中所面臨的主要挑戰(zhàn)之一:SA機(jī)制的計(jì)算資源需求隨輸入尺寸的平方增長,這使得處理高分辨率輸入變得不切實(shí)際。UFO-ViT通過提出一種新的SA機(jī)制,消除了非線性操作,實(shí)現(xiàn)了對計(jì)算復(fù)雜度的線性控制,同時保持了高性能。
2、機(jī)制
1、自注意力機(jī)制的簡化:
UFO-ViT通過消除softmax非線性,簡化了自注意力機(jī)制。它采用簡單的L2范數(shù)替代softmax函數(shù),利用矩陣乘法的結(jié)合律先計(jì)算鍵和值的乘積,然后再與查詢相乘。
2、跨標(biāo)準(zhǔn)化(XNorm):
UFO-ViT引入了一種新的規(guī)范化方法——跨標(biāo)準(zhǔn)化(XNorm),用于替換softmax。這種方法將鍵和值的自注意力乘積直接相乘,并通過線性核方法生成聚類,以線性復(fù)雜度處理自注意力。
3、獨(dú)特優(yōu)勢
1、線性復(fù)雜度:
與傳統(tǒng)的具有O(N^2)復(fù)雜度的SA機(jī)制不同,UFO-ViT的SA機(jī)制具有線性復(fù)雜度,使其能夠有效處理高分辨率輸入。
2、適用于通用目的:
UFO-ViT在圖像分類和密集預(yù)測任務(wù)中的表現(xiàn)證明了其作為通用模型的潛力。盡管復(fù)雜度線性,但UFO-ViT模型在較低容量和FLOPS下仍然超越了大多數(shù)現(xiàn)有的基于Transformer的模型。
3、推理速度快,GPU內(nèi)存需求低:
UFO-ViT模型具有更快的推理速度,并且在各種分辨率下所需的GPU內(nèi)存較少。特別是在處理高分辨率輸入時,所需的計(jì)算資源并沒有顯著增加。
4、代碼
import numpy as np
import torch
from torch import nn
from torch.functional import norm
from torch.nn import init
# 定義XNorm函數(shù),對輸入x進(jìn)行規(guī)范化
def XNorm(x, gamma):
norm_tensor = torch.norm(x, 2, -1, True)
return x * gamma / norm_tensor
# UFOAttention類繼承自nn.Module
class UFOAttention(nn.Module):
'''
實(shí)現(xiàn)一個改進(jìn)的自注意力機(jī)制,具有線性復(fù)雜度。
'''
# 初始化函數(shù)
def __init__(self, d_model, d_k, d_v, h, dropout=.1):
'''
:param d_model: 模型的維度
:param d_k: 查詢和鍵的維度
:param d_v: 值的維度
:param h: 注意力頭數(shù)
'''
super(UFOAttention, self).__init__()
# 初始化四個線性層:為查詢、鍵、值和輸出轉(zhuǎn)換使用
self.fc_q = nn.Linear(d_model, h * d_k)
self.fc_k = nn.Linear(d_model, h * d_k)
self.fc_v = nn.Linear(d_model, h * d_v)
self.fc_o = nn.Linear(h * d_v, d_model)
self.dropout = nn.Dropout(dropout)
# gamma參數(shù)用于規(guī)范化
self.gamma = nn.Parameter(torch.randn((1, h, 1, 1)))
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.h = h
self.init_weights()
# 權(quán)重初始化
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
# 前向傳播
def forward(self, queries, keys, values):
b_s, nq = queries.shape[:2]
nk = keys.shape[1]
# 通過線性層將查詢、鍵、值映射到新的空間
q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)
k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)
v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)
# 計(jì)算鍵和值的乘積,然后對結(jié)果進(jìn)行規(guī)范化
kv = torch.matmul(k, v) # bs,h,c,c
kv_norm = XNorm(kv, self.gamma) # bs,h,c,c
q_norm = XNorm(q, self.gamma) # bs,h,n,c
out = torch.matmul(q_norm, kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
out = self.fc_o(out) # (b_s, nq, d_model)
return out
if __name__ == '__main__':
# 示例用法
block = UFOAttention(d_model=512, d_k=512, d_v=512, h=8).cuda()
input = torch.rand(64, 64, 512).cuda()
output = block(input, input, input)
print(output.shape)