-
Notifications
You must be signed in to change notification settings - Fork 512
/
ssd_heads.py
240 lines (210 loc) · 8.11 KB
/
ssd_heads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torchvision.ops.roi_align import RoIAlign
from corenet.modeling.layers import ConvLayer2d, SeparableConv2d, TransposeConvLayer2d
from corenet.modeling.misc.init_utils import initialize_conv_layer
from corenet.modeling.modules import BaseModule
class SSDHead(BaseModule):
    """
    Prediction head of the `SSD object detector <https://arxiv.org/abs/1512.02325>`_.

    One convolution emits ``n_anchors * (n_coordinates + n_classes)`` channels; :meth:`forward`
    reshapes that map and hands back the coordinate and class predictions as two tensors.

    Args:
        opts: command-line arguments
        in_channels (int): :math:`C` of the expected :math:`(N, C, H, W)` input
        n_anchors (int): Number of anchors
        n_classes (int): Number of classes in the dataset
        n_coordinates (Optional[int]): Number of values predicted per box. Default: 4 (x, y, w, h)
        proj_channels (Optional[int]): Output width of the 1x1 projection placed in front of the
            prediction layer. No projection is built when this is `-1`, when it equals `in_channels`,
            or when `kernel_size` is not greater than 1. Default: -1
        kernel_size (Optional[int]): Kernel size of the prediction layer. `ConvLayer2d` is used when
            kernel_size=1 and `SeparableConv2d` otherwise. When both `kernel_size` and `stride` exceed 1,
            the kernel is enlarged to at least the stride (rounded up to the next odd value). Default: 3
        stride (Optional[int]): Sampling rate applied to the prediction map. For stride > 1 only every
            `stride`-th row and column (starting at `stride // 2`) is kept. Default: 1
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        n_anchors: int,
        n_classes: int,
        n_coordinates: Optional[int] = 4,
        proj_channels: Optional[int] = -1,
        kernel_size: Optional[int] = 3,
        stride: Optional[int] = 1,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        # A projection is only worthwhile if it changes the width and a non point-wise
        # prediction kernel follows it.
        needs_projection = proj_channels not in (-1, in_channels) and kernel_size > 1
        if needs_projection:
            self.proj_layer = ConvLayer2d(
                opts=opts,
                in_channels=in_channels,
                out_channels=proj_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                bias=False,
                use_norm=True,
                use_act=True,
            )
            self.proj_channels = proj_channels
            # Everything downstream consumes the projected features.
            in_channels = proj_channels
        else:
            self.proj_layer = None
            self.proj_channels = None

        # Point-wise kernels use a plain convolution, anything else a separable one.
        predictor_cls = SeparableConv2d
        if kernel_size == 1:
            predictor_cls = ConvLayer2d

        if kernel_size > 1 and stride > 1:
            # Smallest odd kernel that is at least as large as the sampling stride.
            min_kernel = stride + 1 if stride % 2 == 0 else stride
            if min_kernel > kernel_size:
                kernel_size = min_kernel

        # The convolution itself always runs with stride 1; sub-sampling happens in `_sample_fm`.
        self.loc_cls_layer = predictor_cls(
            opts=opts,
            in_channels=in_channels,
            out_channels=n_anchors * (n_coordinates + n_classes),
            kernel_size=kernel_size,
            stride=1,
            groups=1,
            bias=True,
            use_norm=False,
            use_act=False,
        )

        self.in_channel = in_channels
        self.n_anchors = n_anchors
        self.n_classes = n_classes
        self.n_coordinates = n_coordinates
        self.k_size = kernel_size
        self.stride = stride

        self.reset_parameters()

    def __repr__(self) -> str:
        fields = (
            f"in_channels={self.in_channel}, n_anchors={self.n_anchors}, "
            f"n_classes={self.n_classes}, n_coordinates={self.n_coordinates}, "
            f"kernel_size={self.k_size}, stride={self.stride}"
        )
        if self.proj_layer is not None:
            fields += f", proj=True, proj_channels={self.proj_channels}"
        return f"{self.__class__.__name__}({fields})"

    def reset_parameters(self) -> None:
        """Re-initialize every `nn.Conv2d` inside this head with the `xavier_uniform` method."""
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            initialize_conv_layer(module=conv, init_method="xavier_uniform")

    def _sample_fm(self, x: Tensor) -> Tensor:
        """Keep every `stride`-th column and row of `x`, starting at offset `stride // 2`."""
        height, width = x.shape[-2:]
        offset = max(0, self.stride // 2)

        # Columns (last dim) are reduced first, then rows (second to last dim).
        for axis, extent in ((-1, width), (-2, height)):
            keep = torch.arange(
                start=offset,
                end=extent,
                step=self.stride,
                dtype=torch.int64,
                device=x.device,
            )
            x = torch.index_select(x, dim=axis, index=keep)
        return x

    def forward(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """
        Predict box coordinates and class scores for a feature map.

        Args:
            x: Input feature map; the comments below assume a :math:`(B, C, H, W)` layout.

        Returns:
            A tuple `(box_locations, box_classes)` whose last dimensions are `n_coordinates` and
            `n_classes` respectively.
        """
        num_images = x.shape[0]
        per_anchor_dim = self.n_coordinates + self.n_classes

        if self.proj_layer is not None:
            x = self.proj_layer(x)

        # [B x C x H x W] --> [B x Anchors * (coordinates + classes) x H x W]
        x = self.loc_cls_layer(x)

        if self.stride > 1:
            x = self._sample_fm(x)

        # [B x Anchors * (coordinates + classes) x H x W] --> [B x H x W x Anchors * (coordinates + classes)]
        # --> [B x H*W*Anchors X (coordinates + classes)]
        x = x.permute(0, 2, 3, 1).contiguous().view(num_images, -1, per_anchor_dim)

        # [B x H*W*Anchors X (coordinates + classes)] --> [B x H*W*Anchors X coordinates], [B x H*W*Anchors X classes]
        box_locations, box_classes = torch.split(
            x, [self.n_coordinates, self.n_classes], dim=-1
        )
        return box_locations, box_classes
class SSDInstanceHead(BaseModule):
    """
    Instance segmentation head for SSD model.

    Features for each box are pooled with `RoIAlign` and then passed through a transposed
    convolution followed by a 1x1 convolution that emits `n_classes` channels.
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        n_classes: Optional[int] = 1,
        inner_dim: Optional[int] = 256,
        output_stride: Optional[int] = 1,
        output_size: Optional[int] = 8,
        *args,
        **kwargs
    ) -> None:
        """
        Args:
            opts: command-line arguments
            in_channels (int): :math:`C` of the expected :math:`(N, C, H, W)` input
            n_classes (Optional[int]): Number of output channels of the final layer. Default: 1
            inner_dim: (Optional[int]): Channels produced by the transposed convolution. Default: 256
            output_stride (Optional[int]): Ratio of input size to feature map size; its reciprocal is
                passed to RoIAlign as `spatial_scale`. Default: 1
            output_size (Optional[int]): `output_size` of the RoIAlign layer. Default: 8
        """
        super().__init__()

        self.roi_align = RoIAlign(
            output_size=output_size,
            spatial_scale=1.0 / output_stride,
            sampling_ratio=2,
            aligned=True,
        )

        # kernel_size=2 with stride=2 and no padding -- presumably doubles the spatial size
        # of each pooled RoI (depends on TransposeConvLayer2d; confirm there).
        upsample = TransposeConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=inner_dim,
            kernel_size=2,
            stride=2,
            bias=True,
            use_norm=False,
            use_act=True,
            auto_padding=False,
            padding=0,
            output_padding=0,
        )
        # Point-wise layer mapping the inner features to `n_classes` channels.
        mask_predictor = ConvLayer2d(
            opts=opts,
            in_channels=inner_dim,
            out_channels=n_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )
        self.seg_head = nn.Sequential(upsample, mask_predictor)

        self.in_channels = in_channels
        self.inner_channels = inner_dim
        self.mask_classes = n_classes

        self.reset_parameters()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(in_channels={self.in_channels}, "
            f"up_out_channels={self.inner_channels}, n_classes={self.mask_classes})"
        )

    def reset_parameters(self) -> None:
        """Re-initialize all (transposed) 2D convolutions with the `kaiming_normal` method."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, conv_types):
                continue
            initialize_conv_layer(module=module, init_method="kaiming_normal")

    def forward(self, x: Tensor, boxes: Tensor, *args, **kwargs) -> Tensor:
        """
        Pool the regions given by `boxes` from `x` and run the segmentation layers on them.

        Args:
            x: Feature map handed to RoIAlign.
            boxes: Regions of interest, forwarded unchanged to RoIAlign.

        Returns:
            Output of the segmentation layers for the pooled regions.
        """
        return self.seg_head(self.roi_align(x, boxes))