栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch使用过程错误与解决 -汇总~

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch使用过程错误与解决 -汇总~

Pytorch使用过程错误与解决
  • error1:关键词 copy tensor
  • error2:关键词 张量相加
  • error3:关键词 nn.Linear()的使用
    • 报错1:
      • 报错代码:
      • 错误原因:
    • 报错2:
      • 报错代码:
      • 错误原因:
    • 解决办法
      • 错误原因:
      • 正确代码

error1:关键词 copy tensor

报错信息:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

解决办法1:
*** 当转换某个变量为tensor时,尽量使用torch.as_tensor()***

原错误代码:
state_tensor_list = [torch.tensor(i) for i in batch.state]

修改为:
state_tensor_list = [torch.as_tensor(i) for i in batch.state]

解决办法2:
*** 当转换某个变量x为tensor时,尽量使用x.clone().detach() or x.clone().detach().requires_grad_(True) ***

原错误代码:
state_tensor_list = [torch.tensor(i) for i in batch.state]

修改为:
state_tensor_list = [i.clone().detach() for i in batch.state]
error2:关键词 张量相加

参考链接:

报错信息:
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
解决办法:
检查张量维度!!!

错误代码
next_state_values = torch.tensor([0.7056, 0.7165, 0.6326])
state_action_values=torch.tensor([[ 0.1139,  0.1139,  0.1139,  0.1139],
        [ 0.0884,  0.0884,  0.0884,  0.0884],
        [ 0.0019,  0.0019,  0.0019,  0.0019]])
print(next_state_values.shape)
print(state_action_values.shape)
print(next_state_values.size())
print(state_action_values.size())
next_state_values + state_action_values

结果:
torch.Size([3])
torch.Size([3, 4])
torch.Size([3])
torch.Size([3, 4])
修改代码:
next_state_values = torch.tensor([0.7056, 0.7165, 0.6326])
state_action_values=torch.tensor([[ 0.1139,  0.1139,  0.1139,  0.1139],
        [ 0.0884,  0.0884,  0.0884,  0.0884],
        [ 0.0019,  0.0019,  0.0019,  0.0019]]).max(1)[0]
print(next_state_values.shape)
print(state_action_values.shape)
print(next_state_values.size())
print(state_action_values.size())
next_state_values + state_action_values

结果:
torch.Size([3])
torch.Size([3])
torch.Size([3])
torch.Size([3])
tensor([0.8195, 0.8049, 0.6345])

error3:关键词 nn.Linear()的使用

参考链接:

报错1:

报错:RuntimeError: expected scalar type Long but found Float

报错代码:
import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self,state_dim,mid_dim,action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim))
    def forward(self,state):
        res = self.net(state)
        return res
net = Net(9,5,4)
print(net)
current_state = torch.tensor([0,0,0,0,0,0,0,0,0])
print(current_state.shape)
action = net(current_state)
print(action)

结果:
Net(
  (net): Sequential(
    (0): Linear(in_features=9, out_features=5, bias=True)
  )
)
torch.Size([9])
torch.Size([9])
错误原因:

将一维张量转换成二维张量后才能输入

报错2:

报错:RuntimeError: expected scalar type Float but found Long

报错代码:
import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self,state_dim,mid_dim,action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim))
    def forward(self,state):
        res = self.net(state)
        return res
net = Net(9,5,4)
print(net)
current_state = torch.tensor([0,0,0,0,0,0,0,0,0])
print(current_state.shape)
current_state = current_state.view(1,9)
print(current_state.shape)
action = net(current_state)
print(action)
结果:
Net(
  (net): S
  equential(
    (0): Linear(in_features=9, out_features=5, bias=True)
  )
)
torch.Size([9])
torch.Size([1, 9])
错误原因:

传入的张量数据类型应该为float32

解决办法 错误原因:

1、将张量转为二维使用.view(1,shape)
2、指定张量数据类型 dtype=torch.float32

正确代码
import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self,state_dim,mid_dim,action_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, mid_dim))
    def forward(self,state):
        res = self.net(state)
        return res
net = Net(9,5,4)
print(net)
current_state = torch.tensor([0,0,0,0,0,0,0,0,0],**dtype=torch.float32**)
print(current_state.shape)
**current_state = current_state.view(1,9)**
print(current_state.shape)
action = net(current_state)
print(action)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/619321.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号