Attention Mechanisms in CV

This post focuses on attention mechanisms in the vision domain, i.e. on 2D feature maps. Unlike attention over 1D sequence data, these modules generally do not compute q, k, v triplets. Below is a summary of how attention in CV has developed.

Coordinate Attention

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyleft (C) 2024 proanimer, Inc. All Rights Reserved
# author: proanimer
# createTime: 2024/2/24 4:49 PM
# lastModifiedTime: 2024/2/24 4:49 PM
# file: coordinate_attention.py
# software: classicNets

import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        # pool along each spatial axis separately: (n, c, h, 1) and (n, c, 1, w)
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)                      # (n, c, h, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, w, 1)

        # share one 1x1 conv across both directions by concatenating along dim 2
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()          # attention along height
        a_w = self.conv_w(x_w).sigmoid()          # attention along width

        out = identity * a_w * a_h

        return out
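A quick shape check (the channel count and spatial size below are arbitrary test values, not from the original code): the module is a drop-in block that returns a tensor of the same shape, reweighted by the two directional attention maps.

if __name__ == "__main__":
    x = torch.randn(2, 32, 56, 56)            # (batch, channels, height, width)
    att = CoordAtt(inp=32, oup=32, reduction=32)
    out = att(x)
    print(out.shape)                          # torch.Size([2, 32, 56, 56]), shape preserved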

Deformable Attention

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyleft (C) 2024 proanimer, Inc. All Rights Reserved
# author: proanimer
# createTime: 2024/3/4 10:51 PM
# lastModifiedTime: 2024/3/4 10:50 PM
# file: deformable_attention.py
# software: classicNets

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange


# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def divisible_by(numer, denom):
    return (numer % denom) == 0


# tensor helpers


def create_grid_like(t, dim=0):
    h, w, device = *t.shape[-2:], t.device

    grid = torch.stack(
        torch.meshgrid(
            torch.arange(w, device=device),
            torch.arange(h, device=device),
            indexing="xy",
        ),
        dim=dim,
    )

    grid.requires_grad = False
    grid = grid.type_as(t)
    return grid


def normalize_grid(grid, dim=1, out_dim=-1):
    # normalizes a grid to range from -1 to 1
    h, w = grid.shape[-2:]
    grid_h, grid_w = grid.unbind(dim=dim)

    grid_h = 2.0 * grid_h / max(h - 1, 1) - 1.0
    grid_w = 2.0 * grid_w / max(w - 1, 1) - 1.0

    return torch.stack((grid_h, grid_w), dim=out_dim)


class Scale(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return x * self.scale


# continuous positional bias from SwinV2


class CPB(nn.Module):
    """https://arxiv.org/abs/2111.09883v1"""

    def __init__(self, dim, *, heads, offset_groups, depth):
        super().__init__()
        self.heads = heads
        self.offset_groups = offset_groups

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(nn.Linear(2, dim), nn.ReLU()))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(nn.Linear(dim, dim), nn.ReLU()))

        self.mlp.append(nn.Linear(dim, heads // offset_groups))

    def forward(self, grid_q, grid_kv):
        device, dtype = grid_q.device, grid_kv.dtype

        grid_q = rearrange(grid_q, "h w c -> 1 (h w) c")
        grid_kv = rearrange(grid_kv, "b h w c -> b (h w) c")

        pos = rearrange(grid_q, "b i c -> b i 1 c") - rearrange(
            grid_kv, "b j c -> b 1 j c"
        )
        bias = torch.sign(pos) * torch.log(
            pos.abs() + 1
        )  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            bias = layer(bias)

        bias = rearrange(bias, "(b g) i j o -> b (g o) i j", g=self.offset_groups)

        return bias


# main class


class DeformableAttention2D(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        dropout=0.0,
        downsample_factor=4,
        offset_scale=None,
        offset_groups=None,
        offset_kernel_size=6,
        group_queries=True,
        group_key_values=True
    ):
        super().__init__()
        offset_scale = default(offset_scale, downsample_factor)
        assert (
            offset_kernel_size >= downsample_factor
        ), "offset kernel size must be greater than or equal to the downsample factor"
        assert divisible_by(offset_kernel_size - downsample_factor, 2)

        offset_groups = default(offset_groups, heads)
        assert divisible_by(heads, offset_groups)

        inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        self.heads = heads
        self.offset_groups = offset_groups

        offset_dims = inner_dim // offset_groups

        self.downsample_factor = downsample_factor

        self.to_offsets = nn.Sequential(
            nn.Conv2d(
                offset_dims,
                offset_dims,
                offset_kernel_size,
                groups=offset_dims,
                stride=downsample_factor,
                padding=(offset_kernel_size - downsample_factor) // 2,
            ),
            nn.GELU(),
            nn.Conv2d(offset_dims, 2, 1, bias=False),
            nn.Tanh(),
            Scale(offset_scale),
        )

        self.rel_pos_bias = CPB(
            dim // 4, offset_groups=offset_groups, heads=heads, depth=2
        )

        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Conv2d(
            dim, inner_dim, 1, groups=offset_groups if group_queries else 1, bias=False
        )
        self.to_k = nn.Conv2d(
            dim,
            inner_dim,
            1,
            groups=offset_groups if group_key_values else 1,
            bias=False,
        )
        self.to_v = nn.Conv2d(
            dim,
            inner_dim,
            1,
            groups=offset_groups if group_key_values else 1,
            bias=False,
        )
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x, return_vgrid=False):
        """
        b - batch
        h - heads
        x - height
        y - width
        d - dimension
        g - offset groups
        """

        heads, b, h, w, downsample_factor, device = (
            self.heads,
            x.shape[0],
            *x.shape[-2:],
            self.downsample_factor,
            x.device,
        )

        # queries

        q = self.to_q(x)

        # calculate offsets - offset MLP shared across all groups

        group = lambda t: rearrange(
            t, "b (g d) ... -> (b g) d ...", g=self.offset_groups
        )

        grouped_queries = group(q)
        offsets = self.to_offsets(grouped_queries)

        # calculate grid + offsets

        grid = create_grid_like(offsets)
        vgrid = grid + offsets

        vgrid_scaled = normalize_grid(vgrid)

        kv_feats = F.grid_sample(
            group(x),
            vgrid_scaled,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )

        kv_feats = rearrange(kv_feats, "(b g) d ... -> b (g d) ...", b=b)

        # derive key / values

        k, v = self.to_k(kv_feats), self.to_v(kv_feats)

        # scale queries

        q = q * self.scale

        # split out heads

        q, k, v = map(
            lambda t: rearrange(t, "b (h d) ... -> b h (...) d", h=heads), (q, k, v)
        )

        # query / key similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        # relative positional bias

        grid = create_grid_like(x)
        grid_scaled = normalize_grid(grid, dim=0)
        rel_pos_bias = self.rel_pos_bias(grid_scaled, vgrid_scaled)
        sim = sim + rel_pos_bias

        # numerical stability

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        # aggregate and combine heads

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        out = self.to_out(out)

        if return_vgrid:
            return out, vgrid

        return out
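A usage sketch, assuming einops is installed (pip install einops); the dimensions below are illustrative values I picked, not from the original repo. Keys and values are sampled at offset locations on a grid downsampled by downsample_factor, so attention runs over far fewer positions than the full feature map.

if __name__ == "__main__":
    x = torch.randn(1, 128, 64, 64)
    attn = DeformableAttention2D(
        dim=128,                 # feature dimension
        dim_head=32,             # dimension per head
        heads=4,
        downsample_factor=4,     # offsets are predicted on a 16x16 downsampled grid
        offset_kernel_size=6,
    )
    out = attn(x)
    print(out.shape)             # torch.Size([1, 128, 64, 64])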

BAM

import torch
import torch.nn as nn
from torch.nn import init


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction: int = 16, num_layer: int = 3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_channels = [channel]
        gate_channels += [channel // reduction] * num_layer
        gate_channels += [channel]

        self.ca = nn.Sequential()
        self.ca.add_module('flatten', Flatten())
        for i in range(num_layer):
            self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
            self.ca.add_module('relu%d' % i, nn.ReLU())
        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        res = self.avg_pool(x)
        res = self.ca(res)
        # broadcast the per-channel weights back to the input shape
        return res.unsqueeze(-1).unsqueeze(-1).expand_as(x)


class SpatialAttention(nn.Module):
    def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):
        super().__init__()
        self.sa = nn.Sequential()
        self.sa.add_module('conv_reduce1', nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1))
        self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))
        self.sa.add_module('relu_reduce1', nn.ReLU())
        for i in range(num_layers):
            # padding must equal the dilation for a 3x3 conv to preserve the spatial size
            self.sa.add_module('conv_%d' % i, nn.Conv2d(in_channels=channel // reduction, out_channels=channel // reduction, kernel_size=3, padding=dia_val, dilation=dia_val))
            self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))
            self.sa.add_module('relu_%d' % i, nn.ReLU())
        # a single spatial map, broadcast across channels in forward()
        self.sa.add_module('conv_last', nn.Conv2d(in_channels=channel // reduction, out_channels=1, kernel_size=1))

    def forward(self, x):
        res = self.sa(x)
        return res.expand_as(x)


class BAMBlock(nn.Module):
    def __init__(self, channel: int = 512, reduction: int = 16, dia_val: int = 2):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        # initialize weights for the model
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        sa_out = self.sa(x)
        ca_out = self.ca(x)
        # combine the two branches, gate with a sigmoid, and refine residually
        weight = self.sigmoid(sa_out + ca_out)
        out = (1 + weight) * x
        return out
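A smoke test with assumed sizes (batch size 4 so the BatchNorm1d layers in the channel branch behave in training mode): BAM builds its gate from the sum of a channel branch and a dilated-conv spatial branch, then refines the input residually as (1 + M(x)) * x.

if __name__ == "__main__":
    x = torch.randn(4, 512, 14, 14)
    bam = BAMBlock(channel=512, reduction=16, dia_val=2)
    out = bam(x)
    print(out.shape)    # torch.Size([4, 512, 14, 14])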

CBAM

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, input_dim, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared two-layer MLP applied to both pooled descriptors
        self.fc0 = nn.Linear(input_dim, input_dim // ratio, bias=False)
        self.fc1 = nn.Linear(input_dim // ratio, input_dim, bias=False)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc1(self.relu(self.fc0(self.avg_pool(x).view(b, c))))
        max_out = self.fc1(self.relu(self.fc0(self.max_pool(x).view(b, c))))
        # CBAM sums the two descriptors and gates them with a sigmoid
        return self.sigmoid(avg_out + max_out).view(b, c, 1, 1)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # channel-wise average and max, stacked into a 2-channel map
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)
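CBAM applies the two gates sequentially: channel attention first, then spatial attention. A minimal composition sketch; the CBAMBlock wrapper name and the test tensor are my additions, not from the referenced implementation.

class CBAMBlock(nn.Module):
    # sequential gating as in the CBAM paper: channel attention, then spatial attention
    def __init__(self, channel, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channel, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x


if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    out = CBAMBlock(64)(x)
    print(out.shape)    # torch.Size([2, 64, 32, 32])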

SKNet

import torch
import torch.nn as nn


class SKConv(nn.Module):
    """
    Selective Kernel convolution: https://arxiv.org/pdf/1903.06586.pdf
    """

    def __init__(self, feature_dim, WH, M, G, r, stride=1, L=32):
        """ Constructor
        Args:
            feature_dim: input channel dimensionality.
            WH: input spatial dimensionality (kept for API compatibility; pooling is adaptive).
            M: the number of branches.
            G: number of convolution groups.
            r: the ratio for computing d, the length of z.
            stride: stride, default 1.
            L: the minimum dim of the vector z in the paper, default 32.
        """
        super().__init__()
        d = max(int(feature_dim / r), L)
        self.M = M
        self.feature_dim = feature_dim
        self.convs = nn.ModuleList()
        for i in range(M):
            # branch i uses kernel size 3 + 2*i with matching padding
            self.convs.append(nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(feature_dim),
                nn.ReLU(inplace=False)
            ))
        # global average pooling to (b, c, 1, 1); a fixed-size pool of WH/stride
        # breaks as soon as the input size differs from WH
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(feature_dim, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(nn.Linear(d, feature_dim))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # stack the M branch outputs along a new dimension: (b, M, c, h, w)
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)

        fea_U = torch.sum(feas, dim=1)       # fuse: element-wise sum over branches
        fea_s = self.gap(fea_U).flatten(1)   # squeeze: (b, c)
        fea_z = self.fc(fea_s)               # reduce: (b, d)
        # one attention vector per branch, softmax-normalized across branches
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)
        attention_vectors = self.softmax(attention_vectors).unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)   # select: weighted branch sum
        return fea_v


class SKUnit(nn.Module):
    def __init__(self, in_features, out_features, WH, M, G, r, mid_features=None, stride=1, L=32):
        """ Constructor
        Args:
            in_features: input channel dimensionality.
            out_features: output channel dimensionality.
            WH: input spatial dimensionality.
            M: the number of branches.
            G: number of convolution groups.
            r: the ratio for computing d, the length of z.
            mid_features: the channel dim of the middle conv, default out_features/2.
            stride: stride.
            L: the minimum dim of the vector z in the paper.
        """
        super().__init__()
        if mid_features is None:
            mid_features = int(out_features // 2)
        self.feas = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 1),
            nn.BatchNorm2d(mid_features),
            SKConv(mid_features, WH, M, G, r, stride=stride, L=L),
            nn.BatchNorm2d(mid_features),
            nn.Conv2d(mid_features, out_features, 1),
            nn.BatchNorm2d(out_features)
        )

        if in_features == out_features and stride == 1:
            self.shortcut = nn.Sequential()
        else:
            # project (and, if needed, downsample) the shortcut to match the main path
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride=stride),
                nn.BatchNorm2d(out_features)
            )

    def forward(self, x):
        fea = self.feas(x)
        return fea + self.shortcut(x)


class SKNet(nn.Module):
    def __init__(self, class_num):
        super().__init__()
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),   # padding keeps 32x32 inputs at 32x32
            nn.BatchNorm2d(64)
        )
        # the WH arguments (32, 16, 8) imply each stage halves the spatial size,
        # so the first unit of each stage downsamples with stride=2
        self.stage_1 = nn.Sequential(
            SKUnit(64, 256, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 8, 1),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 2, 1),
            nn.ReLU(),
        )
        self.stage_2 = nn.Sequential(
            SKUnit(256, 512, 16, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(512, 512, 16, 2, 8, 1),
            nn.ReLU(),
            SKUnit(512, 512, 16, 2, 2, 1),
            nn.ReLU(),
        )
        self.stage_3 = nn.Sequential(
            SKUnit(512, 1024, 8, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(1024, 1024, 8, 2, 8, 1),
            nn.ReLU(),
            SKUnit(1024, 1024, 8, 2, 2, 1),
            nn.ReLU(),
        )

        self.pool = nn.AdaptiveAvgPool2d(1)   # global pooling instead of a fixed AvgPool2d(8)
        self.classifier = nn.Linear(1024, class_num)

    def forward(self, x):
        fea = self.basic_conv(x)
        fea = self.stage_1(fea)
        fea = self.stage_2(fea)
        fea = self.stage_3(fea)
        fea = self.pool(fea)
        fea = torch.flatten(fea, 1)
        fea = self.classifier(fea)
        return fea
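A CIFAR-style smoke test; the 32x32 input and 10 classes are assumptions consistent with the WH arguments (32, 16, 8) used above, not from the original snippet.

if __name__ == "__main__":
    x = torch.randn(2, 3, 32, 32)
    net = SKNet(class_num=10)
    logits = net(x)
    print(logits.shape)    # torch.Size([2, 10])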

More implementations: xmu-xiaoma666/External-Attention-pytorch: 🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐ (github.com)

References

  1. Attention Mechanisms in Computer Vision: A Survey (arxiv.org/abs/2111.07624)
  2. arxiv.org/abs/2112.05561
  3. CBAM: Convolutional Block Attention Module (arxiv.org/abs/1807.06521)
  4. Coordinate Attention for Efficient Mobile Network Design (arxiv.org/abs/2103.02907)
  5. CCNet: Criss-Cross Attention for Semantic Segmentation (arxiv.org/abs/1811.11721)
  6. Non-local Neural Networks (arxiv.org/abs/1711.07971)
  7. Squeeze-and-Excitation Networks (arxiv.org/abs/1709.01507)
  8. Selective Kernel Networks (arxiv.org/abs/1903.06586)
  9. Convolution-Attention Mechanism Fusion (mlrad.io)

More vision transformer models can be found at lucidrains/vit-pytorch: Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch (github.com)
