栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

深度强化学习 学术前沿与实战应用——PPO

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

深度强化学习 学术前沿与实战应用——PPO

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import random
import gym
import time
# PPO actor-critic模型
class Model(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Model, self).__init__()
        h_size_1 = 100
        h_size_2 = 100
self.v_fc1 = nn.Linear(num_inputs, h_size_1*5)
        self.v_fc2 = nn.Linear(h_size_1*5, h_size_2)
        self.v = nn.Linear(h_size_2, 1)
self.p_fc1 = nn.Linear(num_inputs, h_size_1)
        self.p_fc2 = nn.Linear(h_size_1, h_size_2)
        self.mu = nn.Linear(h_size_2, num_outputs)
        self.log_std = nn.Parameter(torch.zeros(1, num_outputs))
        for name, p in self.named_parameters():
            # init parameters
            if 'bias' in name:
                p.data.fill_(0)
        self.train()
    def forward(self, inputs):
        # actor
        x = F.tanh(self.p_fc1(inputs))
        x = F.tanh(self.p_fc2(x))
        mu = self.mu(x)
        sigma_sq = torch.exp(self.log_std)
        # critic
        x = F.tanh(self.v_fc1(inputs))
        x = F.tanh(self.v_fc2(x))
        v = self.v(x)
        return mu, sigma_sq, v
# 定义共享梯度区类
class Shared_grad_buffers():
    def __init__(self, model):
        self.grads = {}
        for name, p in model.named_parameters():
            self.grads[name+'_grad'] = torch.ones(p.sizes()).share_memory_()
    def add_gradient(self, model):
        for name, p in model.named_parameters():
            self.grads[name+'_grad'] += p.grad.data
    def reset(self):
        for name, grad in self.grads.items():
            self.grads[name].fill_(0)
# 定义状态的规范化
class Shared_obs_stats():
    def __init__(self, num_inputs):
        self.n = torch.zeros(num_inputs).share_memory_()
        self.mean = torch.zeros(num_inputs).share_memory_()
        self.mean_diff = torch.zeros(num_inputs).share_memory_()
        self.var = torch.zeros(num_inputs).share_memory_()

 

    def observes(self, obs):
        # observation mean var updates
        x = obs.data.squeeze()
        self.n += 1
        last_mean = self.mean.clone()
        self.mean += (x-self.mean)/self.n
        self.mean_diff += (x-last_mean)*(x-self.mean)
        self.var = torch.clamp(self.mean_diff/self.n, min=1e-2)

 

    def normalize(self, inputs):
        obs_mean = Variable(self.mean.unsqueeze(0).expand_as(inputs))
        obs_std = Variable(torch.sqrt(self.var).unsqueeze(0).expand_as(inputs))
        return torch.clamp((inputs-obs_mean)/obs_std, -5., 5.)

 

# 经验复用类
class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
 def push(self, events):
        for event in zip(*events):
            self.memory.append(event)
            if len(self.memory) > self.capacity:
                del self.memory[0]

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/283589.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号