YoloV8改进策略:三元注意力,小参数大能力,即插即用,涨点自如

摘要

注意力机制Q在计算机视觉领域得到了广泛的研究和应用,利用构建通道或空间位置之间的依赖关系的能力,有效地应用于各种计算机视觉任务。本文研究了轻量级但有效的注意力机制,并提出了一种新的计算注意力权重的方法一一三元组注意力,通过一个三分支结构捕捉跨维度交互。对于输入张量,三元组注意力通过旋转操作和残差变换建立跨维度的依赖关系,并以极小的计算开销编码了跨通道和空间信息。这种方法既简单又高效,可以轻松地插入经典的主干网络中作为附加模块。在各种具有挑战性的任务中,如lm@geNet-1k图像分类和MSCOCO和PASCAL VOC数据集上的目标检测,证明了该方法的有效性。此外,通过可视化检查Gr阳dCAM和GradCAM-++结果,提供了对三元组注意力性能的深入见解。本文方法的实证评估支持了在计算注意力权重
时捕捉跨维度依赖关系的重要性的直觉。相关代码可以在.上公开访问。https://github.com/LandskapeAl/triplet-attention

三元注意力机制

三元组注意力机制是一种注意力机制,旨在有效地处理跨维度的交互。它由三个分支组成,每个分支负责捕捉输入中空间维度和通道维度之间的跨维度交互特征。具体来说,对于一个输入张量X∈RC×H×W,该机制首先将输入传递给每个分支进行操作。每个分支负责聚合输入中特定维度与通道维度之间的交互特征!

第一分支负责处理输入中空间维度H和W与通道维度C之间的交互特征。它通过在空间维度上应用最大池化和平均池化操作,然后将结果展平并沿着通道维度连接,以获得跨空间维度的交互特征。第二分支负责处理输入中空间维度H和W与通道维度C之间的交互特征。它首先对输入进行全局平均池化操作,然后使用1×1卷积核将结果展平并沿着通道维度连接,以获得跨空间维度的交互特征。第三分支负责处理输入中通道维度C与空间维度H和W之间的交互特征。它首先对输入进行全局最大池化操作,然后使用1×1卷积核将结果展平并沿着通道维度连接,以获得跨通道维度的交互特征。最后,将三个分支的结果连接起来,得到最终的跨维度交互特征。这种机制可以有效地捕捉输入中不同维度之间的交互特征,从而更好地理解图像内容。

三元组注意力机制的优点

 

三元组注意力机制相对于其他注意力机制,如自注意力、多头注意力等,具有以下优势和特点:

1.捕捉三元组信息:三元组注意力机制能够捕捉到三个元素之间的相互作用关系,从而更好地理解输入信息。这种机制可以有效地应用于各种任务,如视觉目标检测、语言翻译、语音识别等。

2.计算效率高:相较于其他注意力机制,三元组注意力机制的计算效率更高。它只关注三个元素之间的相互作用,而不是对整个输入进行计算,从而减少了计算量和时间复杂度。

3.可扩展性强:三元组注意力机制可以方便地扩展到更大的输入尺寸。由于它只关注三个元素之间的相互作用,因此可以在保持较低计算复杂度的同时,对更大的输入进行操作。

4.适用于各种数据类型:三元组注意力机制可以适用于各种数据类型,如图像、文本、音频等。由于它关注的是三个元素之间的相互作用,因此可以广泛应用于各种不同的任务和领域。

5.可解释性强:三元组注意力机制具有更强的可解释性。它可以清楚地解释哪些三元组对输出有影响,从而使得模型更容易理解和调试。

实验结果主要表明了三方面:

1.对比于单一路线注意力机制,tripleti注意力机制在多个标准图像识别数据集上,如ImageNet、Pascal VOC等,都表现出了优越的性能。

2.在一些轻量级的模型,如MobileNetV2上,tripleti注意力机制在保证精度的同时,参数的增加并不多,为约 0.03%。

3.与其他几种注意力机制相比,triplet注意力机制在参数数量上是最少的,且在ImageNet数据集上的top-1错误率降低了0.98%。

 

YOLOv8l summary (fused): 268 layers, 43631280 parameters, 0 gradients, 165.0 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 29/29 [
                   all        230       1412      0.922      0.957      0.986      0.737
                   c17        230        131      0.973      0.992      0.995      0.825
                    c5        230         68      0.945          1      0.995      0.836
            helicopter        230         43       0.96      0.907      0.951      0.607
                  c130        230         85      0.984          1      0.995      0.655
                   f16        230         57      0.955      0.965      0.985      0.669
                    b2        230          2      0.704          1      0.995      0.722
                 other        230         86      0.903      0.942      0.963      0.534
                   b52        230         70       0.96      0.971      0.978      0.831
                  kc10        230         62      0.999      0.984       0.99      0.847
               command        230         40       0.97          1      0.995      0.811
                   f15        230        123      0.891          1      0.992      0.701
                 kc135        230         91      0.971      0.989      0.986      0.712
                   a10        230         27          1      0.555      0.899      0.456
                    b1        230         20      0.972          1      0.995      0.793
                   aew        230         25      0.945          1       0.99      0.784
                   f22        230         17      0.913          1      0.995      0.725
                    p3        230        105       0.99          1      0.995      0.801
                    p8        230          1      0.637          1      0.995      0.597
                   f35        230         32      0.939      0.938      0.978      0.574
                   f18        230        125      0.985      0.992      0.987      0.817
                   v22        230         41      0.983          1      0.995       0.69
                 su-27        230         31      0.925          1      0.995      0.859
                 il-38        230         27      0.972          1      0.995      0.811
                tu-134        230          1      0.663          1      0.995      0.895
                 su-33        230          2          1      0.611      0.995      0.796
                 an-70        230          2      0.766          1      0.995       0.73
                 tu-22        230         98      0.984          1      0.995      0.831
Speed: 0.2ms preprocess, 3.8ms inference, 0.0ms loss, 0.8ms postprocess per image

三元注意力代码

 

 

### For latest triplet_attention module code please refer to the corresponding file in root.

import torch
import torch.nn as nn


class BasicConv(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias=False,
    ):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.bn = (
            nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
            if bn
            else None
        )
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
        )


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale


class TripletAttention(nn.Module):
    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=["avg", "max"],
        no_spatial=False,
    ):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out

改进一:将三元注意力加入到C2模块中,重构C2模块
改进方法
在ultralytics/nn/modules文件夹下,新建triplet_attention.py脚本,将上面的代a码复制进去,如下图:

然后,在block.py中导入TripletAttention,如下图:

然后,将TripletAttention加入到C2f模块中,代码如下:

class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
        self.triplet_attention = TripletAttention(c2, 4)

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.triplet_attention(self.cv2(torch.cat(y, 1)))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.triplet_attention(self.cv2(torch.cat(y, 1)))

TripletAttention的第一个参数表示channle,第二个参数表示缩放的倍数,一般设置为4或者16(看了源码才发现,这个其实不用设置,源码中,这两个参数都没有用到)。最后,在项目的根目录添加train.py脚本,代码如下:

 

from ultralytics import YOLO
import os

if __name__ == '__main__':
    # 加载模型
    model = YOLO(model="ultralytics/cfg/models/v8/yolov8l.yaml")  # 从头开始构建新模型

    # Use the model
    results = model.train(data="VOC.yaml", epochs=300, device='0', batch=8, seed=42,)  # 训练模

训练完成后,就可以看到测试结果!在项目的根目录添加val.py脚本,代码如下:


from ultralytics import YOLO


if __name__ == '__main__':
    # Load a model
    # model = YOLO('yolov8m.pt')  # load an official model
    model = YOLO('runs/detect/train/weights/best.pt')  # load a custom model


    # Validate the model
    metrics = model.val(split='val')  # no arguments needed, dataset and settings remembered

split-=va'代表使用验证集做测试,如果改为split=test,则使用测试集做测试!在项目的根目录添加test.py脚本,代码如下:

from ultralytics import YOLO


if __name__ == '__main__':
    # Load a model
    # model = YOLO('yolov8m.pt')  # load an official model
    model = YOLO('runs/detect/train/weights/best.pt')  # load a custom model
    results = model.predict(source="ultralytics/assets", device='0',save=True)  # predict on an image
    print(results)

test脚本测试assets文件夹下面的图片,save设置为true,则保存图片的测试结果!

 

测试结果

                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 15/15 [00:01<00:00,  8.19it/s]
                   all        230       1412      0.969      0.964      0.991      0.745
                   c17        230        131      0.996      0.992      0.995      0.828
                    c5        230         68      0.986          1      0.995      0.817
            helicopter        230         43      0.976      0.966      0.984      0.585
                  c130        230         85          1      0.971      0.994      0.648
                   f16        230         57          1       0.93      0.989      0.681
                    b2        230          2      0.916          1      0.995      0.723
                 other        230         86          1      0.921      0.979       0.56
                   b52        230         70      0.984      0.971      0.981      0.826
                  kc10        230         62      0.997      0.984      0.989      0.837
               command        230         40      0.995          1      0.995      0.821
                   f15        230        123      0.963      0.984      0.993        0.7
                 kc135        230         91      0.982      0.989      0.984      0.705
                   a10        230         27          1      0.522      0.968      0.509
                    b1        230         20      0.998          1      0.995      0.752
                   aew        230         25       0.94          1      0.995      0.789
                   f22        230         17      0.887          1      0.995      0.771
                    p3        230        105          1      0.982      0.995      0.814
                    p8        230          1      0.841          1      0.995      0.796
                   f35        230         32          1       0.83      0.989      0.611
                   f18        230        125      0.984      0.983      0.992      0.829
                   v22        230         41      0.995          1      0.995      0.696
                 su-27        230         31      0.993          1      0.995      0.857
                 il-38        230         27      0.994          1      0.995      0.823
                tu-134        230          1      0.833          1      0.995      0.895
                 su-33        230          2          1          1      0.995      0.697
                 an-70        230          2       0.91          1      0.995      0.726
                 tu-22        230         98      0.999          1      0.995      0.819
Speed: 0.1ms preprocess, 5.0ms inference, 0.0ms loss, 0.6ms postprocess per image

 

改进二:将三元注意力加入到主干网络后面

改进方法
对官方的源码做了修改,代码如下:


class TripletAttention(nn.Module):
    def __init__(
            self,
            gate_channels,
            reduction_ratio=16,
            pool_types=["avg", "max"],
            no_spatial=False,
    ):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        x_out = torch.sigmoid(x_out)
        x = x * x_out
        return x

将返回值改为和输入的矩阵相乘。
将triplet_attention.py复制到ultralytics/nn/modules/文件夹下面,如下图:

 

然后,在_init.py中导入TripletAttention,如下图:

在task.py中导入TripletAttention,如下图:


from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
                                    Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
                                    Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
                                    RTDETRDecoder, Segment,TripletAttention)

在parsemodel函数中,增加TripletAttention模块参数配置的逻辑,如下图:

 

        elif m is TripletAttention:
            args = [ch[f],*args]

修改配置文件,代码如下:

 

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
  - [-1, 1, TripletAttention, [1024, 16]]  # 9
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

最后,在项目的根目录添加train.py脚本,代码如下:

 

from ultralytics import YOLO
import os

if __name__ == '__main__':
    # 加载模型
    model = YOLO(model="ultralytics/cfg/models/v8/yolov8l.yaml")  # 从头开始构建新模型

    # Use the model
    results = model.train(data="VOC.yaml", epochs=300, device='0', batch=8, seed=42,)  # 训练模

训练完成后,就可以看到测试结果!
在项目的根目录添加vl.py脚本,代码如下:

 


from ultralytics import YOLO


if __name__ == '__main__':
    # Load a model
    # model = YOLO('yolov8m.pt')  # load an official model
    model = YOLO('runs/detect/train/weights/best.pt')  # load a custom model


    # Validate the model
    metrics = model.val(split='val')  # no arguments needed, dataset and settings remembered

split-=va'代表使用验证集做测试,如果改为split=test,则使用测试集做测试!在项目的根目录添加test.py脚本,代码如下:

from ultralytics import YOLO


if __name__ == '__main__':
    # Load a model
    # model = YOLO('yolov8m.pt')  # load an official model
    model = YOLO('runs/detect/train/weights/best.pt')  # load a custom model
    results = model.predict(source="ultralytics/assets", device='0',save=True)  # predict on an image
    print(results)

test脚本测试assets.文件夹下面的图片,save设置为true,则保存图片的测试结果!

 

测试结果

                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 15/15 [00:01<00:00,  9.17it/s]
                   all        230       1412      0.956      0.959      0.987      0.748
                   c17        230        131      0.973      0.992      0.995      0.816
                    c5        230         68      0.944      0.988      0.993      0.832
            helicopter        230         43      0.968          1      0.985      0.591
                  c130        230         85      0.988      0.994      0.995      0.663
                   f16        230         57          1      0.939      0.977      0.667
                    b2        230          2      0.887          1      0.995      0.849
                 other        230         86      0.981      0.919      0.973      0.546
                   b52        230         70      0.979      0.971      0.983      0.824
                  kc10        230         62      0.997      0.984      0.988      0.841
               command        230         40      0.993          1      0.995      0.849
                   f15        230        123       0.94      0.984      0.992      0.713
                 kc135        230         91      0.985      0.989      0.987      0.689
                   a10        230         27      0.965      0.667      0.868      0.446
                    b1        230         20      0.972       0.95      0.993      0.669
                   aew        230         25      0.935          1      0.995      0.802
                   f22        230         17      0.935          1      0.995      0.721
                    p3        230        105          1      0.996      0.995      0.811
                    p8        230          1      0.804          1      0.995      0.895
                   f35        230         32       0.98      0.969      0.993      0.567
                   f18        230        125      0.986      0.992      0.992      0.835
                   v22        230         41       0.99          1      0.995      0.671
                 su-27        230         31       0.98          1      0.995       0.86
                 il-38        230         27      0.986          1      0.995      0.825
                tu-134        230          1      0.792          1      0.995      0.895
                 su-33        230          2          1      0.571      0.995      0.697
                 an-70        230          2      0.868          1      0.995      0.796
                 tu-22        230         98      0.997          1      0.995      0.823
Speed: 0.1ms preprocess, 3.7ms inference, 0.0ms loss, 0.6ms postprocess per image

总结本文尝试了使用三元注意力机制改进YooV8,虽然是比较老的论文,但是效果还是可以的。两种改进均有提点,但没有增加太多的运算量。代码和PDF的文章详见:

 

本内容需要 登录 后才能查看

 

 

版权声明:
作者:建模忠哥
链接:http://jianmozhongge.cn/2023/12/03/yolov8%e6%94%b9%e8%bf%9b%e7%ad%96%e7%95%a5%ef%bc%9a%e4%b8%89%e5%85%83%e6%b3%a8%e6%84%8f%e5%8a%9b%ef%bc%8c%e5%b0%8f%e5%8f%82%e6%95%b0%e5%a4%a7%e8%83%bd%e5%8a%9b%ef%bc%8c%e5%8d%b3%e6%8f%92%e5%8d%b3/
来源:建模忠哥
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
打赏
海报
YoloV8改进策略:三元注意力,小参数大能力,即插即用,涨点自如
摘要 注意力机制Q在计算机视觉领域得到了广泛的研究和应用,利用构建通道或空间位置之间的依赖关系的能力,有效地应用于各种计算机视觉任务。本文研究了轻量级……
<<上一篇
下一篇>>
文章目录
关闭
目 录