栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

基于深度强化学习的绘画智能体 代码分析(四)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

基于深度强化学习的绘画智能体 代码分析(四)

Github源码

tensorboard.py

import PIL #图像处理库
import scipy.misc #将数组保存成图像形式
from io import BytesIO #在内存中读写bytes
import tensorboardX as tb
from tensorboardX.summary import Summary

class TensorBoard(object):
    def __init__(self, model_dir): #model_dir是下载模型保存地址
        self.summary_writer = tb.FileWriter(model_dir) #指定一个文件用来保存图

    def add_image(self, tag, img, step):
        summary = Summary()
        bio = BytesIO() #创建一个类二进制文件对象

        if type(img) == str:
            img = PIL.Image.open(img)  #返回PIL.Image.Image的类型
        elif type(img) == PIL.Image.Image:
            pass #不需要转换
        else:
            img = PIL.Image.fromarray(img) #array转换成image

        img.save(bio, format="png")
        image_summary = Summary.Image(encoded_image_string=bio.getvalue()) #可视化
        summary.value.add(tag=tag, image=image_summary) #按照标签加入进去
        self.summary_writer.add_summary(summary, global_step=step)  #global_step训练步数

    def add_scalar(self, tag, value, step): #加的是scalar具体的值
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.summary_writer.add_summary(summary, global_step=step)

Github源码
util.py

import os
import torch
from torch.autograd import Variable

USE_CUDA = torch.cuda.is_available() #你电脑GPU能否PyTorch调用。

def prRed(prt): print("33[91m {}33[00m" .format(prt))
def prGreen(prt): print("33[92m {}33[00m" .format(prt))
def prYellow(prt): print("33[93m {}33[00m" .format(prt))
def prLightPurple(prt): print("33[94m {}33[00m" .format(prt))
def prPurple(prt): print("33[95m {}33[00m" .format(prt))
def prCyan(prt): print("33[96m {}33[00m" .format(prt))
def prLightGray(prt): print("33[97m {}33[00m" .format(prt))
def prBlack(prt): print("33[98m {}33[00m" .format(prt))

def to_numpy(var): #把tensor变成numpy
    return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 
#.data是读取Variable中的tensor   .cpu是把数据转移到cpu    .numpy()把tensor变成numpy

def to_tensor(ndarray, device): #和上面的相反
    return torch.tensor(ndarray, dtype=torch.float, device=device)

def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()): #parameters()会返回一个生成器(迭代器),生成器每次生成的是Tensor类型的数据.
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau  #加了tau(0~1),复制一部分
        )

def hard_update(target, source):
    for m1, m2 in zip(target.modules(), source.modules()):
        m1._buffers = m2._buffers.copy()
    for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data) #source.parameters一对一的全部复制到target_param

def get_output_folder(parent_dir, env_name):
    """Return save folder. #返回保存文件夹。
    Assumes folders in the parent_dir have suffix -run{run
    number}. #假定父目录中的文件夹具有后缀-run{run number}
   Finds the highest run number and sets the output folder
    to that number + 1. #查找最高的运行编号,并将输出文件夹设置为该编号+1。
   This is just convenient so that if you run the
    same script multiple times tensorboard can plot all of the results
    on the same plots with different names. #这非常方便,如果您多次运行同一脚本,tensorboard可以使用不同的名称在相同的绘图上绘制所有结果。
    Parameters
    ----------
    parent_dir: str
      Path of the directory containing all experiment runs.
    Returns
    -------
    parent_dir/run_dir
      Path to this run's save directory.
    """
    os.makedirs(parent_dir, exist_ok=True) #创建目录
    experiment_id = 0
    for folder_name in os.listdir(parent_dir):
        if not os.path.isdir(os.path.join(parent_dir, folder_name)):
            continue
        try:
            folder_name = int(folder_name.split('-run')[-1])   #获取文件扩展名
            if folder_name > experiment_id:
                experiment_id = folder_name
        except:
            pass
    experiment_id += 1

    parent_dir = os.path.join(parent_dir, env_name)
    parent_dir = parent_dir + '-run{}'.format(experiment_id)
    os.makedirs(parent_dir, exist_ok=True)
    return parent_dir
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/315834.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号