栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

使用DQN解决cartpole问题(深度强化学习入门)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用DQN解决cartpole问题(深度强化学习入门)

使用DQN解决cartpole问题(深度强化学习入门)
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 11:16:50 2021

@author: wss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # 调用relu啥的
import collections
import random
import torch.optim as optim

#放一些参数
Lr = 0.1   #学习率
Buffer_size = 10000 #经验回放的buffer的大小
Eps = 0.1   # eps 贪心算法的随机选择比列
GAMMA = 0.99  # reward的衰减



#用队列存 transition     并定义了采样函数
Transition = collections.namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
# 用一个类来实现经验回放 去除state的相关性和利用经验
class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = collections.deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)




#定义DQN 的神经网络部分

class Net(nn.Module):
    def __init__(self,n_in,n_hidden,n_out):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_out)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out

    
# net属性是神经网络的对象
class DQN(object):
    def __init__(self,n_in,n_hidden,n_out):
#        super(DQN,self).__init__()
        self.net = Net(n_in,n_hidden,n_out)
        self.target_net = Net(n_in,n_hidden,n_out)
        self.optimer = optim.Adam(self.net.parameters(),lr = Lr)
        self.loss_func = nn.MSELoss()
        self.target_net.load_state_dict(self.net.state_dict())
#        self.target_net.eval()     # 解决高估问题  不用训练直接加载policy_net的参数
        self.buffer = ReplayMemory(Buffer_size)
        
        
    #根据state选择 action 
    def select_action(self,state): #返回的action是个数字(不是张量)
        threshold = random.random() 
        Q_actions = self.net(torch.Tensor(state)) #返回不同action对应的Q值
        if  threshold 
刚刚接触深度学习以及强化学习,不知道为什么这个DQN并没有随着训练越来越来越好?
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/589743.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号