栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

复现PointPillar目标检测网络里的PointPillarScatter

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

复现PointPillar目标检测网络里的PointPillarScatter

背景知识

PointPillar 3D点云目标检测模型,提出时间比较久了,模型不做过多介绍,给个参考链接,可自行了解。PointPillar:利用伪图像高效实现3D目标检测 - 云+社区 - 腾讯云

在网络模型里,使用pointnet提取点的特征后,会将pillars(C P)映射成pseudo images(C H W),这里介绍使用pytorch的tensor.scatter函数实现此操作。

模型输入有两个,一个是point(D P N),一个是point对应的indices(P 2)。point经过pointnet后变成(C P)。indices里的2分别是点云网格代后的x,y坐标,按 y*width + x 展开成一维,变成(P),将其重复C份,可变成(C P),与pointnet的输出保持一致。

输出是pseudo images(C H W),其中H * W > P。

下面是实现代码,添加了batch size(N)。

import torch
from torch import nn


class PointPillarScatter(nn.Module):
    def __init__(self, input_shape, indices):
        super(PointPillarScatter, self).__init__()
        self.input_shape = input_shape # H W
        self.indices = indices # N C P

    def forward(self, x):
        # x: N C P
        n = x.shape[0]
        c = x.shape[1]
        h = self.input_shape[0]
        w = self.input_shape[1]
        return torch.zeros((n, c, h * w), dtype=x.dtype).scatter(2, self.indices, x).reshape((n, c, h, w))

测试代码

def test_pps():
    input_shape = (2, 2) # H W
    num_channel = 2 # C
    indices = torch.tensor([[1, 2]], dtype=torch.int64) # N P
    indices = indices.unsqueeze(dim=1) # N 1 P
    indices = indices.repeat_interleave(num_channel, dim=1) # N C P
    x = torch.tensor([[[3.0, 4.0], [1.0, 2.0]]], dtype=torch.float16) # N C P
    print(indices.shape)
    print(x.shape)
    # 将target里的indices所指示位置的值用x里的值替换
    # 通道0, 在替换的维度上,target: 0 ... h * w,indices里标记的位置是1, 2,对应的值是x里的[3, 4]
    pps = PointPillarScatter(input_shape, indices)
    x = pps(x)
    print(x.shape)
    print(x)


if __name__ == "__main__":
    test_pps()

torch.Size([1, 2, 2])
torch.Size([1, 2, 2])
torch.Size([1, 2, 2, 2])
tensor([[[[0., 3.],
          [4., 0.]],

         [[0., 1.],
          [2., 0.]]]], dtype=torch.float16)
 

参考文章:

【Pytorch】scatter函数详解_guofei_fly的博客-CSDN博客_pytorch scatter 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/700293.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号