31、Cloattention模塊
論文《Rethinking Local Perception in Lightweight Vision Transformer》
1、作用
CloFormer(Context-aware Local Enhancement Vision Transformer)是一種輕量級的視覺Transformer,用于在保持模型輕量化的同時,提高在各種視覺任務中的性能,包括圖像分類、目標檢測和語義分割。其主要目的是提升移動設備上的視覺模型性能,克服直接縮減標準ViT(Vision Transformer)模型尺寸導致的性能下降問題。
2、機制
CloFormer通過引入AttnConv(Attention Style Convolution Operator)來實現上下文感知的局部增強,從而有效捕獲高頻局部信息。該模型采用兩分支結構:
1、局部分支:
利用AttnConv融合共享權重和上下文感知權重來聚合高頻局部信息。首先,使用深度卷積(Depthwise Convolution,DWconv)提取局部表示,然后部署上下文感知權重來增強局部特征。
2、全局分支:
采用標準的注意力機制,通過對K和V進行下采樣來降低FLOPs,幫助模型捕捉低頻全局信息。
3、獨特優勢
1、上下文感知的局部增強:
通過AttnConv,CloFormer有效地結合了共享權重和上下文感知權重的優勢,實現了高質量的局部特征增強。
2、兩分支結構:
通過同時捕獲高頻和低頻信息,模型能夠在不同的視覺任務中達到更好的性能。
3、輕量化設計:
CloFormer專為移動設備設計,通過精心的模型架構設計和權重共享機制,實現了在保持輕量化的同時提高模型性能。
4、代碼
import torch
from torch import nn
from efficientnet_pytorch.model import MemoryEfficientSwish
# Module that generates an attention map using convolutions and an activation function
class AttnMap(nn.Module):
    """Turn a feature map into attention weights with a pointwise bottleneck.

    The map is produced by a 1x1 convolution, a ``MemoryEfficientSwish``
    activation and a second 1x1 convolution. Both convolutions keep the
    channel count at ``dim``.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Build the layers in the same order they are applied so parameter
        # initialisation and the Sequential indices (0, 1, 2) stay stable.
        pre_act = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
        activation = MemoryEfficientSwish()
        post_act = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
        self.act_block = nn.Sequential(pre_act, activation, post_act)

    def forward(self, x):
        """Apply conv -> swish -> conv to ``x`` and return the result."""
        attn_map = self.act_block(x)
        return attn_map
# Main module implementing the efficient attention mechanism
class EfficientAttention(nn.Module):
    """Two-branch attention over a ``(b, c, h, w)`` feature map.

    The ``num_heads`` heads are partitioned by ``group_split``. Every entry
    except the last describes a local ("high-frequency") group that is paired
    with the depthwise kernel size at the same index of ``kernel_sizes``. The
    last entry is the number of heads of the global ("low-frequency") branch,
    which attends to keys/values computed from an average-pooled copy of the
    input. The outputs of all active branches are concatenated along the
    channel axis and passed through a 1x1 projection.

    Args:
        dim: Number of input/output channels. Must be divisible by
            ``num_heads``.
        num_heads: Total number of heads across all groups.
        group_split: Heads per group; the last entry belongs to the global
            branch. A zero entry disables that group.
        kernel_sizes: Depthwise kernel size for each local group
            (``len(kernel_sizes) + 1 == len(group_split)``).
        window_size: Kernel and stride of the average pooling applied before
            the global key/value projection (1 disables pooling).
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        qkv_bias: Whether the q/k/v convolutions and the output projection
            carry a bias term.

    Raises:
        AssertionError: If ``group_split`` does not sum to ``num_heads`` or
            its length does not match ``kernel_sizes``.
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, group_split=[4, 4], kernel_sizes=[5], window_size=4,
                 attn_drop=0., proj_drop=0., qkv_bias=True):
        super().__init__()
        assert sum(group_split) == num_heads, \
            "group_split must sum to num_heads"
        assert len(kernel_sizes) + 1 == len(group_split), \
            "group_split needs one entry per kernel size plus one for the global branch"
        if dim % num_heads != 0:
            # Otherwise the concatenated branch outputs would have fewer than
            # `dim` channels and the output projection would fail in forward().
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5  # attention scaling factor
        # Copy the sequences so the shared default lists (or a caller's list)
        # cannot be mutated through, or change underneath, this module.
        self.kernel_sizes = list(kernel_sizes)
        self.window_size = window_size
        self.group_split = list(group_split)

        # One (depthwise conv, attention map, qkv projection) triple per
        # local group that actually owns heads. Groups with zero heads are
        # skipped, so these lists are indexed by *active* group, not by the
        # position in `kernel_sizes`.
        convs = []
        act_blocks = []
        qkvs = []
        for kernel_size, group_head in zip(self.kernel_sizes, self.group_split):
            if group_head == 0:
                continue
            group_channels = self.dim_head * group_head
            convs.append(nn.Conv2d(3 * group_channels, 3 * group_channels, kernel_size,
                                   1, kernel_size // 2, groups=3 * group_channels))
            act_blocks.append(AttnMap(group_channels))
            qkvs.append(nn.Conv2d(dim, 3 * group_channels, 1, 1, 0, bias=qkv_bias))

        if self.group_split[-1] != 0:
            # Global branch: full-resolution queries, pooled keys/values.
            global_channels = self.group_split[-1] * self.dim_head
            self.global_q = nn.Conv2d(dim, global_channels, 1, 1, 0, bias=qkv_bias)
            self.global_kv = nn.Conv2d(dim, global_channels * 2, 1, 1, 0, bias=qkv_bias)
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()

        # Register the per-group modules as submodules.
        self.convs = nn.ModuleList(convs)
        self.act_blocks = nn.ModuleList(act_blocks)
        self.qkvs = nn.ModuleList(qkvs)
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)  # output projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
        """Local branch: gate ``v`` with weights derived from ``q * k``.

        Args:
            x: Input of shape ``(b, c, h, w)``.
            to_qkv: Module producing the stacked q/k/v channels from ``x``.
            mixer: Module applied to the stacked q/k/v tensor (the depthwise
                convolution when called from ``forward``).
            attn_block: Module mapping ``q * k`` to attention logits.

        Returns:
            Tensor of shape ``(b, m * d, h, w)`` for this group.
        """
        b, _, h, w = x.size()
        qkv = to_qkv(x)  # (b (3 m d) h w)
        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous()  # (3 b (m d) h w)
        q, k, v = qkv  # each (b (m d) h w)
        attn = attn_block(q.mul(k)).mul(self.scalor)
        # tanh bounds the element-wise weights to (-1, 1) before dropout.
        attn = self.attn_drop(torch.tanh(attn))
        return attn.mul(v)  # (b (m d) h w)

    def low_fre_attention(self, x: torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
        """Global branch: softmax attention against pooled keys/values.

        Args:
            x: Input of shape ``(b, c, h, w)``.
            to_q: Module producing the query channels from ``x``.
            to_kv: Module producing stacked key/value channels from the
                pooled input.
            avgpool: Module that downsamples ``x`` before ``to_kv``.

        Returns:
            Tensor of shape ``(b, m * d, h, w)`` for the global heads.
        """
        b, _, h, w = x.size()
        q = to_q(x).reshape(b, -1, self.dim_head, h * w).transpose(-1, -2).contiguous()  # (b m (h w) d)
        kv = to_kv(avgpool(x))  # (b (2 m d) H W)
        # Take the token count from the pooled tensor itself: pooling floors
        # each spatial axis, so (h * w) // window_size**2 is wrong whenever
        # h or w is not a multiple of the window size.
        pooled_tokens = kv.shape[-2] * kv.shape[-1]
        kv = kv.reshape(b, 2, -1, self.dim_head, pooled_tokens).permute(1, 0, 2, 4, 3).contiguous()  # (2 b m (H W) d)
        k, v = kv  # each (b m (H W) d)
        attn = self.scalor * q @ k.transpose(-1, -2)  # (b m (h w) (H W))
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v  # (b m (h w) d)
        return res.transpose(2, 3).reshape(b, -1, h, w).contiguous()

    def forward(self, x: torch.Tensor):
        """Run every active branch on ``x`` (``(b, c, h, w)``) and project.

        Returns:
            The projected (and dropout-regularised) concatenation of all
            branch outputs.
        """
        # Iterate the registered modules directly: they already contain only
        # the groups with a non-zero head count, so positional indexing by
        # `kernel_sizes` would go out of range when a zero-head group exists.
        res = [
            self.high_fre_attntion(x, to_qkv, mixer, attn_block)
            for to_qkv, mixer, attn_block in zip(self.qkvs, self.convs, self.act_blocks)
        ]
        if self.group_split[-1] != 0:
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))
# Input: N C H W, output: N C H W
if __name__ == '__main__':
    # Smoke test: push one random batch through the module and print the
    # output shape. Use the GPU only when one is present so the script also
    # runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block = EfficientAttention(64).to(device)  # instantiate the attention module
    sample = torch.rand(1, 64, 64, 64).to(device)  # random (N, C, H, W) input
    output = block(sample)
    print(output.shape)  # print the output shape