My NeRF-PyTorch

Chapter 1  Code

import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from run_nerf_helpers import *
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
from load_LINEMOD import load_LINEMOD_data

os.environ['CUDA_VISIBLE_DEVICES']='0'    # select which GPU the program runs on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")  # presumably equivalent to the two lines above, pinned to GPU 1
np.random.seed(0)      # fix the seed so the same random numbers are generated on every run
DEBUG = False

def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches."""
    if chunk is None:    # chunk: max number of rays processed simultaneously; bounds memory use without changing the result
        return fn
    def ret(inputs):
        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret
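
# Illustrative sketch (not part of the original file; the helper name is ours):
# batchify(fn, 2) applied to 5 rows calls fn on rows 0:2, 2:4, 4:5 and
# concatenates the results, so the output matches fn(x) while peak memory
# stays bounded by the chunk size.
def _demo_batchify():
    fn = lambda x: x * 2.
    x = torch.arange(10.).reshape(5, 2)
    assert torch.equal(batchify(fn, 2)(x), fn(x))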

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'."""
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        input_dirs = viewdirs[:,None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)   # append the embedded viewing directions to the embedded positions
    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
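
# Illustrative sketch (not part of the original file; the identity "embedders"
# and dummy fn are ours): run_network flattens [N_rays, N_samples, 3] points to
# [N_rays*N_samples, 3], embeds them, appends per-ray directions, applies fn in
# netchunk-sized pieces, and restores the leading dimensions.
def _demo_run_network():
    pts = torch.rand(4, 8, 3)                      # [N_rays, N_samples, 3]
    dirs = torch.rand(4, 3)                        # one viewing direction per ray
    identity = lambda x: x
    fn = lambda e: e[..., :1]                      # stand-in "network" with 1 output channel
    out = run_network(pts, dirs, fn, embed_fn=identity, embeddirs_fn=identity, netchunk=16)
    assert out.shape == (4, 8, 1)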

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM."""
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])
    all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, near=0., far=1., use_viewdirs=False, c2w_staticcam=None, **kwargs):
    """Render rays.
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: array of shape [3, 3]. Intrinsics matrix of the pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Controls maximum memory usage; does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin and direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
        camera while using other c2w argument for viewing directions.

    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays
    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()
    sh = rays_d.shape # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()
    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)
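    # Packed per-ray layout: [rays_o(3), rays_d(3), near(1), far(1)] = 8 floats,
    # plus viewdirs(3) when use_viewdirs is set (11 in total); render_rays()
    # unpacks these same slices on its side.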

    # Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)
    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]

def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    H, W, focal = hwf

    if render_factor!=0:
        # Render downsampled for speed
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor
    rgbs = []
    disps = []
    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i==0:
            print(rgb.shape, disp.shape)
        """
        if gt_imgs is not None and render_factor==0:
            p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
            print(p)
        """
        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)
    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps

# Instantiate the NeRF MLP model(s).
def create_nerf(args):
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)   # positional encoder for the 3D position input
    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)  # with use_viewdirs the input is 5D: position plus viewing direction
    output_ch = 5 if args.N_importance > 0 else 4   # N_importance was 128 in the author's run
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())
    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())
    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                         embed_fn=embed_fn,
                                                                         embeddirs_fn=embeddirs_fn,
                                                                         netchunk=args.netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))  # default learning rate 5e-4
    start = 0
    basedir = args.basedir
    expname = args.expname


    ##############################################################################################################################################
    # Load checkpoints (optionally the specific weights file given by --ft_path)
    if args.ft_path is not None and args.ft_path!='None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
    print('Found ckpts', ckpts)       # prints "Found ckpts []" on a fresh run
    if len(ckpts) > 0 and not args.no_reload:    # only taken when a previous run left checkpoints behind (e.g. after a crash or interruption); skipped on a fresh run
        ckpt_path = ckpts[-1]   # most recent checkpoint .tar file
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)
        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])
    ###############################################################################################################################################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:    # skipped for llff-format data unless --no_ndc is set
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer

# Transform the model's raw predictions into semantically meaningful values.
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """
    将模型的预测转换为语义上有意义的值。
     参数:
         raw:[num_rays, num_samples along ray, 4]。 从模型预测。  #模型的预测值
         z_vals: [num_rays, num_samples along ray]。 积分时间。
         rays_d:[num_rays, 3]。 每条射线的方向。
     返回:
         rgb_map: [num_rays, 3]。 光线的估计 RGB 颜色。
         disp_map: [num_rays]。 视差图。 深度图的逆。
         acc_map: [num_rays]。 沿每条射线的权重总和。
         weights:[num_rays, num_samples]。 分配给每个采样颜色的权重。
         depth_map:[num_rays]。 到物体的估计距离。
    """
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists)
    dists = z_vals[...,1:] - z_vals[...,:-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]
    dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)
    alpha = raw2alpha(raw[...,3] + noise, dists)  # [N_rays, N_samples]
    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
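    # weights[i] = alpha[i] * T[i], where the transmittance T[i] = prod_{j<i}(1 - alpha[j]);
    # the prepended ones column makes the cumulative product exclusive, and the 1e-10
    # keeps the product away from exact zeros.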
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)
    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[...,None])

    return rgb_map, disp_map, acc_map, weights, depth_map
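
# Illustrative sketch (not part of the original file; the helper name is ours):
# a single ray with two samples. alpha_i = 1 - exp(-relu(sigma_i) * dist_i) and
# weight_i = alpha_i * T_i; the weights composite the per-sample colors into one
# RGB value, and their sum (acc_map) approaches 1 for an opaque ray.
def _demo_raw2outputs():
    raw = torch.zeros(1, 2, 4)
    raw[..., 3] = 10.                            # large density -> nearly opaque samples
    z_vals = torch.tensor([[1., 2.]])
    rays_d = torch.tensor([[0., 0., 1.]])        # unit-length direction
    rgb, disp, acc, weights, depth = raw2outputs(raw, z_vals, rays_d)
    assert abs(acc.item() - 1.) < 1e-4           # nearly all light is absorbed along the ray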

# Volumetric rendering
def render_rays(ray_batch, network_fn, network_query_fn, N_samples, retraw=False, lindisp=False, perturb=0., N_importance=0,
                network_fine=None, white_bkgd=False, raw_noise_std=0., verbose=False, pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary for sampling along
        a ray, including: ray origin, ray direction, min dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified random points in time.
      N_importance: int. Number of additional times to sample along each ray. These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: std dev of noise added to regularize the sigma output.
      verbose: bool. If True, print more debugging info.

    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from the fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from the fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from the model.
      rgb0: See rgb_map. Output of the coarse model.
      disp0: See disp_map. Output of the coarse model.
      acc0: See acc_map. Output of the coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each sample.
    """

    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1] # [-1,1]
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    z_vals = z_vals.expand([N_rays, N_samples])
    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)
        z_vals = lower + (upper - lower) * t_rand
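        # Each z value is now drawn uniformly from its own bin [lower_i, upper_i]
        # (stratified sampling): z_i = lower_i + (upper_i - lower_i) * u_i with u_i ~ U[0, 1).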
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

#     raw = run_network(pts)
    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
    if N_importance > 0:
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()
        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
        run_fn = network_fn if network_fine is None else network_fine
#         raw = run_network(pts, fn=run_fn)
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")
    return ret
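
# Illustrative sketch (not part of the original file; the helper name is ours):
# the hierarchical sampling above draws extra z values via sample_pdf (imported
# from run_nerf_helpers) in proportion to the coarse weights. With uniform
# weights and det=True, the fine samples spread evenly across the bins.
def _demo_sample_pdf():
    bins = torch.linspace(0., 1., steps=9).expand(2, 9)   # bin edges for 2 rays
    weights = torch.ones(2, 8)                            # one weight per bin
    z_samples = sample_pdf(bins, weights, 4, det=True)
    assert z_samples.shape == (2, 4)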

# Configuration parameters
def config_parser():
    import configargparse    # configargparse: a drop-in replacement for argparse that can also read options from config files
    parser = configargparse.ArgumentParser()  # create the parser; each parser.add_argument below registers one option
    parser.add_argument('--config', is_config_file=True, help='config file path')   # --config names a file whose entries fill in the options below
    parser.add_argument("--expname", type=str, help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8, help='layers in network')                #网络的层数
    parser.add_argument("--netwidth", type=int, default=256, help='channels per layer')               #每层的通道数
    parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network')           #精细网络中的层数
    parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network')  #精细网络中每层的通道数
    parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)')  #batch size一次训练所选取的样本数。
    parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate')     #学习率
    parser.add_argument("--lrate_decay", type=int, default=250, help='exponential learning rate decay (in 1000 steps)')   #指数学习速率衰减(在1000步)
    parser.add_argument("--chunk", type=int, default=1024*32, help='number of rays processed in parallel, decrease if running out of memory') # chunk 并行处理的射线数,如果内存耗尽则减少
    parser.add_argument("--netchunk", type=int, default=1024*64, help='number of pts sent through network in parallel, decrease if running out of memory') #通过网络并行发送的PTS数量,如果内存不足则减少
    parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time')     #no_batching每次只从一张图像中随机提取光线
    parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt')    #从保存的ckpt不重新加载权重
    parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network')    #为粗网络重新加载特定权重的 npy 文件

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,help='number of coarse samples per ray')            #每条射线的粗样本的数量
    parser.add_argument("--N_importance", type=int, default=0,help='number of additional fine samples per ray')   #每条射线的额外精细样品数量
    parser.add_argument("--perturb", type=float, default=1.,help='set to 0. for no jitter, 1. for jitter')      #set to 0. for no jitter, 1. for jitter
    parser.add_argument("--use_viewdirs", action='store_true',help='use full 5D input instead of 3D')             #使用全5D而不是3D输入
    parser.add_argument("--i_embed", type=int, default=0,help='set 0 for default positional encoding, -1 for none')    #embed内嵌、嵌入。设置0为默认位置编码,-1为无,此处为0
    parser.add_argument("--multires", type=int, default=10,help='log2 of max freq for positional encoding (3D location)')   #multires多精度、多分辨率,值为10。位置编码(3D位置)的最大频率的log2
    parser.add_argument("--multires_views", type=int, default=4,help='log2 of max freq for positional encoding (2D direction)')  #值为4, 位置编码(2D方向)的最大频率的log2
    parser.add_argument("--raw_noise_std", type=float, default=0.,help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
    parser.add_argument("--render_only", action='store_true',help='do not optimize, reload weights and render out render_poses path') #不优化,重新加载权重和渲染出render_poses路径
    parser.add_argument("--render_test", action='store_true',help='render the test set instead of render_poses path')  #渲染测试集而不是render_poses路径
    parser.add_argument("--render_factor", type=int, default=0,help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') #降低采样系数以加速渲染,设置4或8快速预览

    # training options
    parser.add_argument("--precrop_iters", type=int, default=0,help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,default=.5, help='fraction of img taken for central crops')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels')  #数据类型
    parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase')

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',help='load blender synthetic data at 400x400 instead of 800x800') #half_res以400x400而不是800x800加载blender合成数据

    ## llff flags
    parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images')  #LLFF图像的下采样因子
    parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') #不要使用标准化的设备坐标(设置为非正向场景)
    parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8')   将每1/N张图像作为LLFF测试集,文章中N=8

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=100, help='frequency of console printout and metric logging')
    parser.add_argument("--i_img",     type=int, default=500, help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=50000, help='frequency of render_poses video saving')

    return parser
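
# Typical usage (a sketch; the exact config filename is an example from the
# upstream nerf-pytorch repo and may differ in your checkout):
#   python run_nerf.py --config configs/fern.txt
# configargparse merges values from the config file with command-line flags,
# with the command line taking precedence.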

# Training
def train():
    parser = config_parser()  # build the full argument parser
    args = parser.parse_args()  # parse the command line (and any config file) into args

    # Load data
    K = None  # camera intrinsics matrix; built from hwf below unless the dataset loader supplies one
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                  recenter=True, bd_factor=.75,
                                                                  spherify=args.spherify)
        hwf = poses[0,:3,-1]
        poses = poses[:,:3,:4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)  # for fern: poses (20,3,5), images (20,378,504,3), bds (20,2)
        if not isinstance(i_test, list):
            i_test = [i_test]
        if args.llffhold > 0:     # default args.llffhold = 8
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]
        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val)])
        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:   # the default (NDC) branch
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        near = 2.
        far = 6.
        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]
    elif args.dataset_type == 'LINEMOD':
        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
        print(f'[CHECK HERE] near: {near}, far: {far}.')
        i_train, i_val, i_test = i_split
        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]
    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to the right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]
    if K is None:
        K = np.array([[focal, 0, 0.5*W],[0, focal, 0.5*H],[0, 0, 1]])   # for fern this gives focal ≈ 407.5658, cx = 252.0, cy = 189.0
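        # Standard pinhole intrinsics: focal length on the diagonal and the principal
        # point at the image center (W/2, H/2); loaders such as LINEMOD supply K directly.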
    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname   # set via --expname on the command line or in the config file
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))   # write every parsed argument to args.txt
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')  # keep a copy of the config file used for this run
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create the NeRF model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)  # create_nerf() returns everything needed for rendering and optimization
    global_step = start
    bds_dict = {'near' : near, 'far' : far,}   # 0. and 1. in the NDC case
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Short circuit if only rendering out from trained model
    if args.render_only:   # a store_true flag: False unless --render_only is passed, so training proceeds by default
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # render_test switches to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None
            testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))  # os.path.join() joins two or more path components
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)
            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand   # batch size: number of random rays per gradient step, default 32*32*4
    use_batching = not args.no_batching    # True unless --no_batching is set; rays are then drawn across all training images at once
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
        rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move training data to GPU
    if use_batching:
        images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 200000 + 1     # total number of training iterations (200k)
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3,:4]
            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(torch.meshgrid(torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")                
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)
                coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                        verbose=i < 10, retraw=True, **render_kwargs_train)
        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)
        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
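        # e.g. with the defaults lrate=5e-4 and lrate_decay=250 this gives
        # lr = 5e-4 * 0.1**(step/250000): ~1.6e-4 after 125k steps, 5e-5 after 250k.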
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################
        dt = time.time()-time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####

        # Rest is logging
        if i%args.i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)
        if i%args.i_video==0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

            # if args.use_viewdirs:
            #     render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
            #     with torch.no_grad():
            #         rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
            #     render_kwargs_test['c2w_staticcam'] = None
            #     imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

        if i%args.i_testset==0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')
        if i%args.i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
        """
            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
                tf.contrib.summary.scalar('loss', loss)
                tf.contrib.summary.scalar('psnr', psnr)
                tf.contrib.summary.histogram('tran', trans)
                if args.N_importance > 0:
                    tf.contrib.summary.scalar('psnr0', psnr0)
            if i%args.i_img==0:
                # Log a rendered validation view to Tensorboard
                img_i=np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3,:4]
                with torch.no_grad():
                    rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,**render_kwargs_test)
                psnr = mse2psnr(img2mse(rgb, target))
                with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                    tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
                    tf.contrib.summary.scalar('psnr_holdout', psnr)
                    tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
                if args.N_importance > 0:
                    with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
                        tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
                        tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
                        tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
        """

        global_step += 1

if __name__=='__main__':
    # set PyTorch's default tensor type to a 32-bit float GPU tensor
    #torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
