简介:Swin Transformer做主干的 Faster RCNN 目标检测网络。
 
一、所需软件(包)介绍
- 项目工程:mmdetection,直接去github拉取代码即可,拉取位置:mmdetection ,确保当前mmdetection版本支持mmcv 1.3.17,因为后面使用的环境是mmcv 1.3.17的,mmdet与mmcv版本对应关系参考:mmdet与mmcv版本 ,如果未来master支持的mmcv版本要求大于1.3.17的话,请按照要求安装对应的版本。
 - 开发环境:与之前 Swin Transformer Object Detection工程所使用的环境相同,安装过程参考:Swin Transformer Object Detection 目标检测-1——环境搭建详细教程
 
二、环境搭建
- 如果之前已经创建了 Swin Transformer Object Detection 项目所需的环境的话,可以直接使用,但是会对后面再训练Swin Transformer Object Detection 造成影响(因为mmdetection工程需要对mmdet的版本进行更改才能使用),所以建议再创建一个新的环境给mmdetection使用,或者直接clone一份之前的环境(推荐)。
 - 克隆环境的方式为:
conda create -n conda-env2 --clone conda-env1- conda-env2 为新创建的环境(从其他环境clone来的)
 - conda-env1 为之前已经有的环境
 
 
注:克隆环境需要一段时间,请耐心等待。这样后面我们mmdetection的工程所使用的环境就是新clone的这个。clone 成功后按照下面步骤操作:
- 在IDE中配置项目所使用的虚拟环境为我们新克隆的
 - 进如到虚拟环境后,在mmdetection的项目目录下执行
python setup.py develop,此时确定 mmdet被换成 2.23.0版本。 
三、Swin Transformer Faster RCNN 网络结构图
Swin Transformer Faster RCNN 没看到什么官方的名字,索性就这么叫吧。实际上就是Swin Transformer 作为Faster RCNN网络的Backbone(主干特征提取网络)。
四、Swin Transformer Faster RCNN 网络代码
1. 在configs/swin 目录下新建文件:faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
文件内容如下:
注意:训练的epoch在这个文件中改,我直接设置成了50,大家根据需要修改。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20_base_ = [
    '../_base_/models/faster_rcnn_swin_fpn.py',
    '../_base_/datasets/faster_rcnn_coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
2. 在 configs/base/models 下新建文件:faster_rcnn_swin_fpn.py
文件内容如下:
注意: num_classes 需要根据你数据集的类别进行更改1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116# model settings
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'
model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=4,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))
3. 在/base/datasets 目录下新建文件:faster_rcnn_coco_instance.py
文件内容如下:
注意: 
- img_scale、samples_per_gpu、 workers_per_gpu可以根据自己的显存大小适当调大、调小
 - 数据集配置部分参考B站教程:数据集标注
 
1  | # dataset settings  | 
4. 修改mmdet/datasets/ 下 coco.py
CLASSES中填写自己的分类:例如 CLASSES = ('person', 'bicycle', 'car')。
当只有一个类别时,多加一个逗号:CLASSES = ('person',)
五、数据集
数据集依然使用默认的coco格式,数据集制作参考数据集标注(LabelImg、LabelMe使用方法)
注:其实这里是可以使用voc格式的,先挖个坑,后面补上。
六、训练模型
直接执行: python tools/train.py configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
注意:第一次执行会下载权值文件,需要等待一段时间,或者用特殊办法快点下载,权值文件会自动保存到你的电脑上,下次运行的时候就不再需要重新下载了,当然也可以和之前一样,提前下载好权值文件,然后配置一下。
七、测试训练效果
添加一个自己的图片在demo目录下,
执行:python demo/image_demo.py demo/000071.jpg configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py work_dirs/faster_rcnn_swin_t-p4-w7_fpn_3x_coco/latest.pth
latest.pth 就是自己训练好的最新的权重文件,默认会放在workdir下。
Q & A
Q1. 报错:ImportError: cannot import name ‘init_random_seed’ from ‘mmdet.apis’
A1:进如到虚拟环境后,在mmdetection的项目目录下执行python setup.py develop ,此时 mmdet被换成 2.23.0版本。

...
...
本文为作者原创文章,未经作者允许不得转载。