Implementing Local Context Normalization in TensorFlow

Reference code: PyTorch implementation for Local Context Normalization: Revisiting Local Normalization
Reference paper: Local Context Normalization: Revisiting Local Normalization

The reference code is written in PyTorch and implements LCN for 2D images. I rewrote it for TensorFlow 1.4 and extended it to 3D images.
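To make clear what each voxel is normalized by, here is a brute-force, pure-Python sketch of the statistics LCN uses: the mean and variance over one window_size cube, pooled over all channels of one group, so that n = window volume × channels_per_group, as in the layer below. The helper names and the volume[c][z][y][x] nested-list layout are illustrative assumptions. The TensorFlow layer computes the same quantities for every window at once using integral images.

```python
import math


def local_context_stats(volume, group, corner, window):
    """Brute-force mean and variance over one window, pooled over one channel group.

    volume: nested lists indexed as volume[c][z][y][x] (channels first, like the
            transposed tensor inside the layer).
    group:  channel indices of one group (channels_per_group entries).
    corner: (z, y, x) of the window's low corner; window: (wd, wh, ww).
    """
    z0, y0, x0 = corner
    wd, wh, ww = window
    values = [volume[c][z][y][x]
              for c in group
              for z in range(z0, z0 + wd)
              for y in range(y0, y0 + wh)
              for x in range(x0, x0 + ww)]
    n = len(values)  # window volume * channels_per_group, the same n as in the layer
    mean = sum(values) / n
    var = sum(v * v for v in values) / n - mean * mean  # E[x^2] - E[x]^2
    return mean, var


def normalize(value, mean, var, eps=1e-5):
    """Normalize one voxel with the statistics of its window and channel group."""
    return (value - mean) / math.sqrt(var + eps)


# Two constant channels (1.0 and 3.0) forming one group of size 2
vol = [[[[float(1 + 2 * c)] * 4 for _ in range(4)] for _ in range(4)] for c in range(2)]
mean, var = local_context_stats(vol, group=[0, 1], corner=(0, 0, 0), window=(3, 3, 3))
print(mean, var)  # 2.0 1.0
```

Pooling over the group is why the two constant channels still produce a non-zero variance: each channel alone would have variance 0.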

Code
import tensorflow as tf
import keras
from keras.layers.core import Layer
import math
import os
import numpy as np


class LocalContextNorm(Layer):  # 3D
    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128):
        super(LocalContextNorm, self).__init__()
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  #[D, H, W]
        self.img_size = img_size

    def build(self, input_shape):

        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')

        if self.img_size <= self.window_size[0]:
            raise Exception('Window size must be smaller than image size in the LCN case!')

        self.num_features = input_shape[-1]
        # Per-channel affine parameters, registered with add_weight so that Keras
        # tracks them in trainable_weights and updates them during training
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        self.built = True

    def call(self, inputs, **kwargs):

        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        inputs_shape = inputs.shape.as_list()
        C = inputs_shape[1]
        D, H, W = self.img_size, self.img_size, self.img_size  # cubic input assumed

        assert C % self.channels_per_group == 0, 'channels must be divisible by channels_per_group'
        G = C // self.channels_per_group  # number of channel groups

        def use_window(inputs):
            # Spatial size of the statistics after the 'valid' box filter
            t = self.img_size - self.window_size[0]

            # 3D integral images of x and x^2: cumulative sums along D, H and W
            inputs_sq = inputs * inputs
            integral_img = tf.cumsum(tf.cumsum(tf.cumsum(inputs, axis=2), axis=3), axis=4)
            integral_img_sq = tf.cumsum(tf.cumsum(tf.cumsum(inputs_sq, axis=2), axis=3), axis=4)

            d = [self.window_size[0], self.window_size[1], self.window_size[2]]
            # 2x2x2 inclusion-exclusion kernel for a 3D integral image (see note 5)
            kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]

            def group_box_sums(integral):
                # Dilating the 2x2x2 kernel by the window size turns the integral image
                # into box sums over every window. There is no 4D convolution, so each
                # channel goes through its own fixed Conv3D and the outputs are concatenated.
                per_channel = []
                for i in range(C):
                    conv = keras.layers.Conv3D(filters=1, kernel_size=2, padding='valid',
                                               kernel_initializer=keras.initializers.Constant(kernel),
                                               strides=[1, 1, 1], dilation_rate=d, use_bias=False,
                                               trainable=False, data_format='channels_first')
                    per_channel.append(conv(tf.expand_dims(integral[:, i, :, :, :], axis=1)))
                box = tf.stop_gradient(tf.concat(per_channel, axis=1))  # [B, C, t, t, t]
                # Sum the channels of each group: [B, G, channels_per_group, t, t, t] -> [B, G, t, t, t]
                box = tf.reshape(box, [-1, G, self.channels_per_group, t, t, t])
                return tf.reduce_sum(box, axis=2)

            sums = group_box_sums(integral_img)
            squares = group_box_sums(integral_img_sq)

            n = self.window_size[0] * self.window_size[1] * self.window_size[2] * self.channels_per_group
            means = sums / n
            var = 1.0 / n * (squares - sums * sums / n)

            # Pad the [B, G, t, t, t] statistics back to [B, G, D, H, W]
            pad3d = [int(math.floor((D - t) / 2)), int(math.ceil((D - t) / 2)),
                     int(math.floor((H - t) / 2)), int(math.ceil((H - t) / 2)),
                     int(math.floor((W - t) / 2)), int(math.ceil((W - t) / 2))]
            paddings = [[0, 0], [0, 0], pad3d[0:2], pad3d[2:4], pad3d[4:6]]
            padded_means = tf.pad(means, paddings, 'REFLECT')
            padded_vars = tf.pad(var, paddings, 'REFLECT') + self.eps

            # Normalize each channel with the statistics of its group
            x = tf.reshape(inputs, [-1, G, self.channels_per_group, D, H, W])
            x = (x - tf.expand_dims(padded_means, 2)) / tf.sqrt(tf.expand_dims(padded_vars, 2))
            return tf.reshape(x, [-1, C, D, H, W])

        # Assumes D == H == W, i.e. cubic inputs whose side equals img_size
        inputs = use_window(inputs)
        inputs = tf.transpose(inputs, [0, 2, 3, 4, 1])

        return inputs * self.weight + self.bias


    def compute_output_shape(self, input_shape):
        return input_shape

class GroupContextNorm(Layer):  # 3D
    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128):
        super(GroupContextNorm, self).__init__()
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size  #[D, H, W]
        self.img_size = img_size

    def build(self, input_shape):

        if len(input_shape) != 5:  # [B, D, H, W, C]
            raise Exception('Input of GroupContextNorm layer should have 5 dims with [B, D, H, W, C]!')

        if self.img_size > self.window_size[0]:
            raise Exception('Window size must be at least the image size in the GN case!')

        self.num_features = input_shape[-1]
        # Per-channel affine parameters, registered with add_weight so that Keras
        # tracks them in trainable_weights and updates them during training
        self.weight = self.add_weight(name='weight', shape=(1, 1, 1, 1, self.num_features),
                                      initializer='ones', trainable=True)
        self.bias = self.add_weight(name='bias', shape=(1, 1, 1, 1, self.num_features),
                                    initializer='zeros', trainable=True)

        self.built = True

    def call(self, inputs, **kwargs):

        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        inputs_shape = inputs.shape.as_list()
        _, C, D, H, W = inputs_shape

        assert C % self.channels_per_group == 0, 'channels must be divisible by channels_per_group'
        G = C // self.channels_per_group  # number of channel groups

        img_size = self.img_size

        def no_use_window(inputs):
            inputs_shape = inputs.shape.as_list()
            inputs = tf.reshape(inputs, [-1, G, inputs_shape[1] // G * img_size * img_size * img_size])
            means, var = tf.nn.moments(inputs, [2], keep_dims=True)
            inputs = tf.reshape((inputs - means) / tf.sqrt(var + self.eps), [-1, C, img_size, img_size, img_size])

            return inputs

        # Assumes D == H == W, i.e. cubic inputs whose side equals img_size
        inputs = no_use_window(inputs)
        inputs = tf.transpose(inputs, [0, 2, 3, 4, 1])

        return inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        return input_shape

if __name__ == '__main__':

    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    #matrix = tf.concat([tf.ones([2, 7, 7, 7, 1]), 2*tf.ones([2, 7, 7, 7, 1]), tf.ones([2, 7, 7, 7, 1]), 2*tf.ones([2, 7, 7, 7, 1])], axis=-1)
    matrix = tf.cumsum((tf.cumsum(tf.cumsum(tf.ones([2, 7, 7, 7, 4]), axis=1), axis=2)), axis=3)
    lcn_layer = LocalContextNorm(channels_per_group=2, window_size=[3, 3, 3], img_size=matrix.shape.as_list()[1])

    matrix_after_lcn = lcn_layer(matrix)

    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    sess.run(tf.global_variables_initializer())

    print(matrix_after_lcn.shape.as_list())  # static output shape
    print(sess.run(matrix_after_lcn[:, :, :, :, 0]))
    print(sess.run(matrix_after_lcn[:, :, :, :, 1]))
Some Notes

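1. Keras raised the error '_TfDeviceCaptureOp' object has no attribute 'type'. TF 1.4 pairs with Keras 2.0.8; reinstalling keras==2.0.8 fixed it.
2. Switching between the windowed and non-windowed branches with tf.cond failed at graph-construction time, so the code above is split into two classes.
3. In TF 1.4, the conv3d in tf.nn.layer and tflearn.layers has no dilation_rate parameter, so keras.layers.Conv3D is used instead.
4. The reference code computes the 2D integral-image box sums with a 3D convolution, so the 3D case would need a 4D convolution. No convolution API goes beyond 3D, so each channel is filtered with its own Conv3D and the results are joined with tf.concat.
5. The 3D integral-image kernel is kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]. It follows from inclusion-exclusion over the eight corners of the box: an entry is +1 when an even number of its indices are 0 and -1 otherwise. For the 2D kernel, see the reference code above.
6. In 'REFLECT' mode, tf.pad cannot pad a dimension by as much as its own size: each padding amount must be at most the dimension size minus 1. The windowed statistics are only img_size - window_size wide and must be padded back to img_size, so window_size in use_window is limited; I keep it at no more than half of img_size.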
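The kernel in note 5 can be checked numerically. The pure-Python sketch below builds a 3D integral image (the three chained cumulative sums), cross-correlates it with the 2x2x2 kernel at dilation d, which is what each per-channel Conv3D with kernel_size=2, dilation_rate=d and padding='valid' computes, and compares the result with a direct sum over each d×d×d box. The helper names are hypothetical; only the kernel values come from the code above. Note that the output side length is size - d, matching img_size - window_size in the layer.

```python
import itertools

# Kernel from note 5; index [a][b][c] picks the low (0) or high (1) corner along z, y, x
KERNEL = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]


def integral_image_3d(vol):
    """Inclusive prefix sums along x, then y, then z (like the chained tf.cumsum calls)."""
    D, H, W = len(vol), len(vol[0]), len(vol[0][0])
    I = [[list(row) for row in plane] for plane in vol]
    for z, y, x in itertools.product(range(D), range(H), range(1, W)):
        I[z][y][x] += I[z][y][x - 1]
    for z, y, x in itertools.product(range(D), range(1, H), range(W)):
        I[z][y][x] += I[z][y - 1][x]
    for z, y, x in itertools.product(range(1, D), range(H), range(W)):
        I[z][y][x] += I[z - 1][y][x]
    return I


def dilated_box_sums(I, d):
    """'valid' cross-correlation of I with KERNEL at dilation d; output side is size - d."""
    D, H, W = len(I), len(I[0]), len(I[0][0])
    out = [[[0.0] * (W - d) for _ in range(H - d)] for _ in range(D - d)]
    for z, y, x in itertools.product(range(D - d), range(H - d), range(W - d)):
        out[z][y][x] = sum(KERNEL[a][b][c] * I[z + a * d][y + b * d][x + c * d]
                           for a, b, c in itertools.product((0, 1), repeat=3))
    return out


def brute_box_sum(vol, z, y, x, d):
    """Direct sum over the d x d x d box that output (z, y, x) covers."""
    return sum(vol[k][j][i]
               for k in range(z + 1, z + d + 1)
               for j in range(y + 1, y + d + 1)
               for i in range(x + 1, x + d + 1))


vol = [[[float(25 * z + 5 * y + x) for x in range(5)] for y in range(5)] for z in range(5)]
fast = dilated_box_sums(integral_image_3d(vol), d=3)
print(len(fast))  # 2, i.e. img_size - window_size
print(all(fast[z][y][x] == brute_box_sum(vol, z, y, x, 3)
          for z, y, x in itertools.product(range(2), repeat=3)))  # True
```

Since Conv3D performs cross-correlation rather than a flipped convolution, the kernel can be used exactly as written, with index 0 meaning the low corner.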
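The 'REFLECT' padding limit can be made concrete. This sketch, with hypothetical helper names, reproduces the pad amounts computed in use_window for a cubic input, where the statistics are t = img_size - window_size wide per axis and are padded by floor and ceil of (img_size - t) / 2, and checks them against the REFLECT requirement that each pad be at most t - 1.

```python
import math


def reflect_pad_ok(img_size, window):
    """Can tf.pad(..., 'REFLECT') grow the windowed statistics back to img_size?

    Mirrors use_window for a cubic input: the statistics are t = img_size - window
    wide per axis and get floor/ceil((img_size - t) / 2) voxels of padding per side.
    REFLECT mode requires every pad amount to be at most t - 1.
    """
    t = img_size - window
    if t <= 0:
        return False  # build() already rejects window >= img_size
    lo = math.floor((img_size - t) / 2)
    hi = math.ceil((img_size - t) / 2)
    return max(lo, hi) <= t - 1


def max_window(img_size):
    """Largest window_size that passes reflect_pad_ok, or None if none does."""
    ok = [w for w in range(1, img_size) if reflect_pad_ok(img_size, w)]
    return max(ok) if ok else None


print(max_window(7))    # 4
print(max_window(128))  # 84
```

By this check the exact limit lies somewhat above half of img_size (4 for the 7×7×7 test volume in the __main__ block, 84 for img_size = 128), so capping window_size at half of img_size leaves a margin.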

If you repost this, please credit: article reposted from www.mshxw.com
Article URL: https://www.mshxw.com/it/664727.html