1.生成anchor
def _meshgrid(self, x, y, row_major=True):
    """Generate the flattened mesh grid of ``x`` and ``y``.

    Every value of ``x`` is paired with every value of ``y``. Both outputs
    are 1-D tensors of length ``len(x) * len(y)``, laid out so that ``x``
    varies fastest (the first ``len(x)`` elements belong to ``y[0]``).

    Args:
        x (torch.Tensor): Grids of x dimension (1-D).
        y (torch.Tensor): Grids of y dimension (1-D).
        row_major (bool, optional): If True, return ``(xx, yy)``; otherwise
            return ``(yy, xx)``, i.e. the y grids first. Defaults to True.

    Returns:
        tuple[torch.Tensor]: The mesh grids of x and y.
    """
    # Tile the whole x vector once per y value: x0 x1 ... x0 x1 ...
    xx = x.repeat(len(y))
    # Repeat each y value len(x) times in a row: y0 y0 ... y1 y1 ...
    yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
    if row_major:
        return xx, yy
    return yy, xx
# Worked example: the first feature-map level, of size (200, 200).
feat_h, feat_w = featmap_size  # e.g. feat_h = 200
# One shift per feature-map cell, scaled by the stride.
shift_x = torch.arange(0, feat_w, device=device) * stride[0]  # e.g. stride = (4, 4)
shift_y = torch.arange(0, feat_h, device=device) * stride[1]
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
# The (x, y) shift is stacked twice so it can be added to all four box
# coordinates (presumably x1, y1, x2, y2 -- see the example anchors below).
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
shifts = shifts.type_as(base_anchors)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)  # (200 * 200 * 3, 4) in this example
# Example value of base_anchors:
# tensor([[-22.6274, -11.3137,  22.6274,  11.3137],
#         [-16.0000, -16.0000,  16.0000,  16.0000],
#         [-11.3137, -22.6274,  11.3137,  22.6274]], device='cuda:0')
2.解码过程(输入为anchor和预测6参数,返回8参数)
# Decode the anchors with the predicted 6-parameter deltas.
# Example shapes: bboxes (the anchors) is (8768, 4), pred_bboxes is
# (8768, 6) and decoded_bboxes is (8768, 1, 8).
# self.means = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
# self.stds = [1.0, 1.0, 1.0, 1.0, 0.5, 0.5]; wh_ratio_clip = 16 / 1000
decoded_bboxes = delta_sp2bbox(
    bboxes,
    pred_bboxes,
    means=self.means,
    stds=self.stds,
    wh_ratio_clip=wh_ratio_clip,
)
def delta_sp2bbox(rois, deltas,
                  means=(0., 0., 0., 0., 0., 0.),
                  stds=(1., 1., 1., 1., 1., 1.),
                  wh_ratio_clip=16 / 1000):
    """Decode 6-parameter deltas on horizontal boxes into 8-parameter polygons.

    Each group of six deltas ``(dx, dy, dw, dh, da, db)`` first moves and
    rescales the roi like an ordinary box delta, giving an outer horizontal
    box. ``da`` / ``db`` then slide one vertex along each edge of that box,
    and the resulting quadrilateral is stretched into a rectangle.

    Args:
        rois (torch.Tensor): Boxes of shape (N, 4), read as (x1, y1, x2, y2).
        deltas (torch.Tensor): Encoded offsets of shape (N, 6 * k).
        means (Sequence[float]): Per-parameter means used to denormalize
            ``deltas``. Defaults to six zeros.
        stds (Sequence[float]): Per-parameter stds used to denormalize
            ``deltas``. Defaults to six ones.
        wh_ratio_clip (float): ``dw`` and ``dh`` are clamped to
            ``+-abs(log(wh_ratio_clip))``. Defaults to 16 / 1000.

    Returns:
        torch.Tensor: Rectangle polygons of shape (N, k, 8), laid out as
        four (x, y) vertices.
    """
    means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 6)
    stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 6)
    denorm_deltas = deltas * stds + means
    dx = denorm_deltas[:, 0::6]
    dy = denorm_deltas[:, 1::6]
    dw = denorm_deltas[:, 2::6]
    dh = denorm_deltas[:, 3::6]
    da = denorm_deltas[:, 4::6]
    db = denorm_deltas[:, 5::6]
    # Bound the log-scale factors so exp() below cannot blow up.
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio)
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Compute center of each roi
    px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
    py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
    # Compute width/height of each roi
    pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
    ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
    # Use exp(network energy) to enlarge/shrink each roi
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    # Use network energy to shift the center of each roi
    gx = px + pw * dx
    gy = py + ph * dy
    # Corners of the decoded outer horizontal box.
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5
    # Limit the offsets to half an edge so the vertices stay on the box.
    da = da.clamp(min=-0.5, max=0.5)
    db = db.clamp(min=-0.5, max=0.5)
    # ga / _ga are mirrored about gx, gb / _gb are mirrored about gy.
    ga = gx + da * gw
    _ga = gx - da * gw
    gb = gy + db * gh
    _gb = gy - db * gh
    # One vertex per edge: y = y1, x = x2, y = y2, x = x1.
    polys = torch.stack([ga, y1, x2, gb, _ga, y2, x1, _gb], dim=-1)
    center = torch.stack([gx, gy, gx, gy, gx, gy, gx, gy], dim=-1)
    center_polys = polys - center
    # Distance from the center to each of the four vertices.
    diag_len = torch.sqrt(
        torch.square(center_polys[..., 0::2])
        + torch.square(center_polys[..., 1::2]))
    max_diag_len, _ = torch.max(diag_len, dim=-1, keepdim=True)
    # Push every vertex out to the longest center distance: equal
    # diagonals turn the centrally symmetric quadrilateral into a rectangle.
    diag_scale_factor = max_diag_len / diag_len
    center_polys = center_polys * diag_scale_factor.repeat_interleave(
        2, dim=-1)
    rectpolys = center_polys + center
    # Return the 8-parameter rectangles; the original returned the
    # never-assigned name ``obboxes`` and raised NameError.
    return rectpolys
3. 8参数转5参数,5参数转水平框
# Usage: convert the 8-parameter rectangle polygons into 5-parameter
# oriented boxes built from (x, y, w, h, theta).
five = rectpoly2obb(eight)
def rectpoly2obb(polys):
    """Convert 8-parameter rectangle polygons into oriented boxes.

    Args:
        polys (torch.Tensor): Polygons of shape (..., 8), four (x, y)
            vertices per polygon.

    Returns:
        The result of ``regular_obb`` applied to a (..., 5) tensor of
        ``(x, y, w, h, theta)``. ``regular_obb`` is defined elsewhere;
        presumably it normalizes the angle/size convention -- confirm.
    """
    # Angle of the edge from vertex 0 to vertex 1, with dy negated.
    edge_dx = polys[..., 2] - polys[..., 0]
    edge_dy = polys[..., 3] - polys[..., 1]
    theta = torch.atan2(-edge_dy, edge_dx)

    # 2x2 matrix [[cos, -sin], [sin, cos]] for every polygon.
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    rotation = torch.stack([cos_t, -sin_t, sin_t, cos_t], dim=-1)
    rotation = rotation.view(*rotation.shape[:-1], 2, 2)

    # Box center is the mean of the four vertices.
    x = polys[..., 0::2].mean(-1)
    y = polys[..., 1::2].mean(-1)
    centre = torch.stack([x, y], dim=-1).unsqueeze(-2)

    # Rotate the centered vertices so the box edges line up with the axes;
    # width and height are then the extents along each axis.
    corners = polys.view(*polys.shape[:-1], 4, 2)
    aligned = torch.matmul(corners - centre, rotation.transpose(-1, -2))
    aligned_x = aligned[..., :, 0]
    aligned_y = aligned[..., :, 1]
    w = aligned_x.max(dim=-1).values - aligned_x.min(dim=-1).values
    h = aligned_y.max(dim=-1).values - aligned_y.min(dim=-1).values

    return regular_obb(torch.stack([x, y, w, h, theta], dim=-1))
def obb2hbb(obboxes):
    """Return the horizontal bounding boxes enclosing oriented boxes.

    Args:
        obboxes (torch.Tensor): Oriented boxes of shape (..., 5), laid out
            as (cx, cy, w, h, theta) with theta in radians.

    Returns:
        torch.Tensor: Boxes of shape (..., 4), laid out as
        (cx - x_bias, cy - y_bias, cx + x_bias, cy + y_bias).
    """
    center, w, h, theta = obboxes.split([2, 1, 1, 1], dim=-1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    half_w = w / 2
    half_h = h / 2
    # Half extents of the rotated rectangle along the x and y axes.
    x_bias = (half_w * cos_t).abs() + (half_h * sin_t).abs()
    y_bias = (half_w * sin_t).abs() + (half_h * cos_t).abs()
    bias = torch.cat([x_bias, y_bias], dim=-1)
    return torch.cat([center - bias, center + bias], dim=-1)