简介:Swin Transformer做主干的 Faster RCNN 目标检测网络。
一、所需软件(包)介绍
- 项目工程:mmdetection,直接去github拉取代码即可,拉取位置:mmdetection ,确保当前mmdetection版本支持mmcv 1.3.17,因为后面使用的环境是mmcv 1.3.17的,mmdet与mmcv版本对应关系参考:mmdet与mmcv版本 ,如果未来master支持的mmcv版本要求大于1.3.17的话,请按照要求安装对应的版本。
- 开发环境:与之前 Swin Transformer Object Detection工程所使用的环境相同,安装过程参考:Swin Transformer Object Detection 目标检测-1——环境搭建详细教程
二、环境搭建
- 如果之前已经创建了 Swin Transformer Object Detection 项目所需的环境的话,可以直接使用,但是会对后面再训练Swin Transformer Object Detection 造成影响(因为mmdetection工程需要对mmdet的版本进行更改才能使用),所以建议再创建一个新的环境给mmdetection使用,或者直接clone一份之前的环境(推荐)。
- 克隆环境的方式为:
conda create -n conda-env2 --clone conda-env1
- conda-env2 为新创建的环境(从其他环境clone来的)
- conda-env1 为之前已经有的环境
注:克隆环境需要一段时间,请耐心等待。这样后面我们mmdetection的工程所使用的环境就是新clone的这个。clone 成功后按照下面步骤操作:
- 在IDE中配置项目所使用的虚拟环境为我们新克隆的
- 进如到虚拟环境后,在mmdetection的项目目录下执行
python setup.py develop
,此时确定 mmdet被换成 2.23.0版本。
三、Swin Transformer Faster RCNN 网络结构图
Swin Transformer Faster RCNN 没看到什么官方的名字,索性就这么叫吧。实际上就是Swin Transformer 作为Faster RCNN网络的Backbone(主干特征提取网络)。
四、Swin Transformer Faster RCNN 网络代码
1. 在configs/swin 目录下新建文件:faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
文件内容如下:
注意:训练的epoch在这个文件中改,我直接设置成了50,大家根据需要修改。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20_base_ = [
'../_base_/models/faster_rcnn_swin_fpn.py',
'../_base_/datasets/faster_rcnn_coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
2. 在 configs/base/models 下新建文件:faster_rcnn_swin_fpn.py
文件内容如下:
注意: num_classes 需要根据你数据集的类别进行更改1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116# model settings
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'
model = dict(
type='FasterRCNN',
backbone=dict(
type='SwinTransformer',
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=dict(
type='FPN',
in_channels=[96, 192, 384, 768],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
scales=[8],
ratios=[0.5, 1.0, 2.0],
strides=[4, 8, 16, 32, 64]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
roi_head=dict(
type='StandardRoIHead',
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='Shared2FCBBoxHead',
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=4,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
# model training and testing settings
train_cfg=dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
match_low_quality=True,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=-1,
pos_weight=-1,
debug=False),
rpn_proposal=dict(
nms_pre=2000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
match_low_quality=False,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False)),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
max_per_img=1000,
nms=dict(type='nms', iou_threshold=0.7),
min_bbox_size=0),
rcnn=dict(
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
))
3. 在/base/datasets 目录下新建文件:faster_rcnn_coco_instance.py
文件内容如下:
注意:
- img_scale、samples_per_gpu、 workers_per_gpu可以根据自己的显存大小适当调大、调小
- 数据集配置部分参考B站教程:数据集标注
1 | # dataset settings |
4. 修改mmdet/datasets/ 下 coco.py
CLASSES中填写自己的分类:例如 CLASSES = ('person', 'bicycle', 'car')
。
当只有一个类别时,多加一个逗号:CLASSES = ('person',)
五、数据集
数据集依然使用默认的coco格式,数据集制作参考数据集标注(LabelImg、LabelMe使用方法)
注:其实这里是可以使用voc格式的,先挖个坑,后面补上。
六、训练模型
直接执行: python tools/train.py configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
注意:第一次执行会下载权值文件,需要等待一段时间,或者用特殊办法快点下载,权值文件会自动保存到你的电脑上,下次运行的时候就不再需要重新下载了,当然也可以和之前一样,提前下载好权值文件,然后配置一下。
七、测试训练效果
添加一个自己的图片在demo目录下,
执行:python demo/image_demo.py demo/000071.jpg configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py work_dirs/faster_rcnn_swin_t-p4-w7_fpn_3x_coco/latest.pth
latest.pth 就是自己训练好的最新的权重文件,默认会放在workdir下。
Q & A
Q1. 报错:ImportError: cannot import name ‘init_random_seed’ from ‘mmdet.apis’
A1:进如到虚拟环境后,在mmdetection的项目目录下执行python setup.py develop
,此时 mmdet被换成 2.23.0版本。
...
...
本文为作者原创文章,未经作者允许不得转载。