深度學(xué)習(xí)模塊27-UFO模塊

26、UFO模塊

論文《UFO-ViT: High Performance Linear Vision Transformer without Softmax》

1、作用

UFO-ViT旨在解決傳統(tǒng)Transformer在視覺任務(wù)中所面臨的主要挑戰(zhàn)之一:SA機(jī)制的計(jì)算資源需求隨輸入尺寸的平方增長,這使得處理高分辨率輸入變得不切實(shí)際。UFO-ViT通過提出一種新的SA機(jī)制,消除了非線性操作,實(shí)現(xiàn)了對計(jì)算復(fù)雜度的線性控制,同時保持了高性能。

2、機(jī)制

1、自注意力機(jī)制的簡化

UFO-ViT通過消除softmax非線性,簡化了自注意力機(jī)制。它采用簡單的L2范數(shù)替代softmax函數(shù),利用矩陣乘法的結(jié)合律先計(jì)算鍵和值的乘積,然后再與查詢相乘。

2、跨標(biāo)準(zhǔn)化(XNorm)

UFO-ViT引入了一種新的規(guī)范化方法——跨標(biāo)準(zhǔn)化(XNorm),用于替換softmax。這種方法將鍵和值的自注意力乘積直接相乘,并通過線性核方法生成聚類,以線性復(fù)雜度處理自注意力。

3、獨(dú)特優(yōu)勢

1、線性復(fù)雜度

與傳統(tǒng)的具有O(N^2)復(fù)雜度的SA機(jī)制不同,UFO-ViT的SA機(jī)制具有線性復(fù)雜度,使其能夠有效處理高分辨率輸入。

2、適用于通用目的

UFO-ViT在圖像分類和密集預(yù)測任務(wù)中的表現(xiàn)證明了其作為通用模型的潛力。盡管復(fù)雜度線性,但UFO-ViT模型在較低容量和FLOPS下仍然超越了大多數(shù)現(xiàn)有的基于Transformer的模型。

3、推理速度快,GPU內(nèi)存需求低

UFO-ViT模型具有更快的推理速度,并且在各種分辨率下所需的GPU內(nèi)存較少。特別是在處理高分辨率輸入時,所需的計(jì)算資源并沒有顯著增加。

4、代碼

import numpy as np
import torch
from torch import nn
from torch.functional import norm
from torch.nn import init

# 定義XNorm函數(shù),對輸入x進(jìn)行規(guī)范化
def XNorm(x, gamma):
    norm_tensor = torch.norm(x, 2, -1, True)
    return x * gamma / norm_tensor

# UFOAttention類繼承自nn.Module
class UFOAttention(nn.Module):
    '''
    實(shí)現(xiàn)一個改進(jìn)的自注意力機(jī)制,具有線性復(fù)雜度。
    '''

    # 初始化函數(shù)
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):
        '''
        :param d_model: 模型的維度
        :param d_k: 查詢和鍵的維度
        :param d_v: 值的維度
        :param h: 注意力頭數(shù)
        '''
        super(UFOAttention, self).__init__()
        # 初始化四個線性層:為查詢、鍵、值和輸出轉(zhuǎn)換使用
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        # gamma參數(shù)用于規(guī)范化
        self.gamma = nn.Parameter(torch.randn((1, h, 1, 1)))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    # 權(quán)重初始化
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    # 前向傳播
    def forward(self, queries, keys, values):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        # 通過線性層將查詢、鍵、值映射到新的空間
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)

        # 計(jì)算鍵和值的乘積,然后對結(jié)果進(jìn)行規(guī)范化
        kv = torch.matmul(k, v)  # bs,h,c,c
        kv_norm = XNorm(kv, self.gamma)  # bs,h,c,c
        q_norm = XNorm(q, self.gamma)  # bs,h,n,c
        out = torch.matmul(q_norm, kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        return out

if __name__ == '__main__':
    # 示例用法
    block = UFOAttention(d_model=512, d_k=512, d_v=512, h=8).cuda()
    input = torch.rand(64, 64, 512).cuda()
    output = block(input, input, input)
    print(output.shape)

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容