栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

训练semantic segmentation时的报错

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

训练semantic segmentation时的报错

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1,3,840,840]

        以上是在训练semantic segmentation时,出现的报错。翻了一圈看到的唯一靠谱的解释。

        原因大概就是label图片需要1维的数据格式(灰度图),但是图片在输入前仍是3维的RGB图片,没转换成1维的。以下是靠谱解释的链接

dcRuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 3, 96, 128] - vision - PyTorch ForumsI am using UNet I have images with masks (black background and highlighted portion contain different RGB colors) and 12 classes. I want to do training using uNET. but i got error. “RuntimeError: 1only batches of spatial…https://discuss.pytorch.org/t/runtimeerror-1only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-size-1-3-96-128/95030Training Semantic Segmentation - #3 by WeiQin_Chuah - vision - PyTorch ForumsHi, I am trying to reproduce PSPNet using PyTorch and this is my first time creating a semantic segmentation model. I understand that for image classification model, we have RGB input = [h,w,3] and label or ground truth…https://discuss.pytorch.org/t/training-semantic-segmentation/49275/3

/tmp/pip-req-build-xlj_h8ax/aten/src/THCUNN/SpatialClassNLLCriterion.cu:106: cunn_SpatialClassNLLCriterion_updateOutput_kernel: block: [4,0,0], thread: [417,0,0] Assertion `t >= 0 && t < n_classes` failed.

RuntimeError: CUDA error: device-side assert triggered

        以上报错是因为训练时候输入的是4类,但是label数据因为用cv2.resize时,采用了默认的cv2.INTER_LINER插值法,插入了其他非类别的数值,改成cv2.INTER_NEAREST最近邻插值法即可。转换语句如下:

cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)

# 不可简写成cv2.resize(image, (tw, th), cv2.INTER_NEAREST),否则仍是默认的cv2.INTER_LINER

参考:

Is there anybody happen this error? - #14 by XiaoAHeng - autograd - PyTorch Forums/opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THCUNN/SpatialClassNLLCriterion.cu:99: void cunn_SpatialClassNLLCriterion_updateOutput_kernel(T *, T *, T *, long *, T *, int, int, int, int, int, long) [with T =…https://discuss.pytorch.org/t/is-there-anybody-happen-this-error/17416/14        ​​​​​​​由于PIL.Image 和cv2读取图片后数据格式不同,(Image读取后的size是宽,高;cv2读取后的shape是高,宽),我就手残把处理图片的代码全改了(其实完全不用),以下是使用差别较大的几个方法:

# 截取图片
label = label[h:h+h, w:w+w]  # cv2
label = label.crop((w, h, w + w, h + h))  # PIL

# 左右翻转
cv2.flip(image, 1)  # cv2
image.transpose(0)  # PIL

# 修改图片尺寸,特别要注意interpolation的对应关系
cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)  # cv2
image.resize((tw, th), interpolation)  # PIL

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/272594.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号