栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

解决calaTrain.py 的训练模型无法在run

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

解决calaTrain.py 的训练模型无法在run

1.  命名空间的问题。对图中的  < with tf.variable_scope('controlNET') as scope: >  进行注释。

         打印张量所在的名称空间代码如下

import os

from tensorflow.python import pywrap_tensorflow

# current_path = os.getcwd()
# model_dir = os.path.join(current_path, 'model.ckpt')
model_dir = '/home/binghong/documents/PycharmProjects_backup/Immitation_Learning/carlaILTrainer-master/test/'
checkpoint_path = os.path.join(model_dir,'model.ckpt') # 保存的ckpt文件名,不一定是这个
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)

2. 注意data.close()的位置

def genData(fileNames = datasetFilesTrain, batchSize = 200):
    #fileNames = datasetFilesTrain
    #branchNum = 3 # Control signal, int ( 2 Follow lane, 3 Left, 4 Right, 5 Straight)
    #batchSize = 200
    batchX = np.zeros((batchSize, 88, 200, 3))
    batchY = np.zeros((batchSize, 28))
    idx = 0
    while True: # to make sure we never reach the end
        counter = 0
        while counter<=batchSize-1:
            idx = np.random.randint(len(fileNames)-1)
            try:
                data = h5py.File(fileNames[idx], 'r')
            except:
                print(idx, fileNames[idx])

            dataIdx = np.random.randint(200-1)
            batchX[counter] = data['rgb'][dataIdx]
            batchY[counter] = data['targets'][dataIdx]
            counter += 1
        #data.close()
        yield (batchX, batchY)
    data.close()

def genBranch(fileNames = datasetFilesTrain, branchNum = 3, batchSize = 200):
    #fileNames = datasetFilesTrain
    #branchNum = 3 # Control signal, int ( 2 Follow lane, 3 Left, 4 Right, 5 Straight)
    #batchSize = 200
    batchX = np.zeros((batchSize, 88, 200, 3))
    batchY = np.zeros((batchSize, 28))
    idx = 0
    while True: # to make sure we never reach the end
        counter = 0
        while counter<=batchSize-1:
            idx = np.random.randint(len(fileNames)-1)
            try:
                data = h5py.File(fileNames[idx], 'r')
            except:
                print(idx, fileNames[idx])

            dataIdx = np.random.randint(200-1)
            if data['targets'][dataIdx][24] == branchNum:
                batchX[counter] = data['rgb'][dataIdx]
                batchY[counter] = data['targets'][dataIdx]
                counter += 1
        #data.close()
        yield (batchX, batchY)
    data.close()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/504030.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号