参考代码:PyTorch implementation for Local Context Normalization: Revisiting Local Normalization
参考文章:Local Context Normalization: Revisiting Local Normalization
代码实现的是torch的code,以及是对2D图像的LCN,笔者改写成了tensorflow 1.4的code以及3D 图像。
Code:

import tensorflow as tf
import keras
from keras.layers.core import Layer
import math
import os
import numpy as np
class LocalContextNorm(Layer): # 3D
    """Local Context Normalization (LCN) for 3D volumes.

    TensorFlow 1.4 / Keras port (extended to 3D) of the PyTorch 2D reference
    implementation of "Local Context Normalization: Revisiting Local
    Normalization". Each group of `channels_per_group` channels is normalized
    by the mean/variance of a local `window_size` neighbourhood, computed
    efficiently with integral images (summed-volume tables), followed by a
    learned per-channel affine transform.

    Input/output layout is channels-last: [B, D, H, W, C].
    NOTE(review): the implementation assumes a cubic volume
    (D == H == W == img_size) and a cubic window — confirm before using
    anisotropic sizes.
    """

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size = 128):
        # channels_per_group: number of channels pooled into one statistics group
        # window_size: local neighbourhood [D, H, W] over which mean/var are computed
        # eps: numerical-stability constant added to the variance
        # img_size: spatial edge length of the (cubic) input volume
        super(LocalContextNorm, self).__init__()
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        # Expect a 5-D channels-last tensor: [B, D, H, W, C].
        if len(input_shape) != 5:
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        if self.img_size <= self.window_size[0]:
            raise Exception('Window size must be smaller than image size in LCN case!')
        self.num_features = input_shape[-1]
        # Learned per-channel affine parameters (gamma / beta), broadcast over B, D, H, W.
        self.weight = tf.Variable(tf.ones([1, 1, 1, 1, self.num_features]), trainable=True)
        self.bias = tf.Variable(tf.zeros([1, 1, 1, 1, self.num_features]), trainable=True)
        #self.weight = tf.ones([1, 1, 1, 1, self.num_features])
        #self.bias = tf.zeros([1, 1, 1, 1, self.num_features])
        self.built = True

    def call(self, inputs, **kwargs):
        # Work internally in channels-first layout: [B, C, D, H, W].
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        inputs_shape = inputs.shape.as_list()
        C = inputs_shape[1]
        D, H, W = self.img_size, self.img_size, self.img_size
        G = math.floor(C / self.channels_per_group)  # number of channel groups
        assert C % self.channels_per_group == 0

        def use_window(inputs):
            # Compute per-window sums of x and x^2 via 3D integral images,
            # then derive group-wise mean/variance from them.
            inputs_sq = inputs * inputs
            integral_img = tf.cumsum(tf.cumsum(tf.cumsum(inputs, axis=2), axis=3), axis=4)
            integral_img_sq = tf.cumsum(tf.cumsum(tf.cumsum(inputs_sq, axis=2), axis=3), axis=4)
            d = [self.window_size[0], self.window_size[1], self.window_size[2]]
            # 2x2x2 inclusion/exclusion kernel: a conv dilated by the window
            # size with these +-1 corner weights turns the integral image into
            # per-window sums (see note 5 at the bottom of the file).
            kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]
            #[[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]]
            # All-ones kernel used later to sum over the channels of a group.
            c_kernel = np.ones((self.channels_per_group, 1, 1)).tolist()
            # integral_img
            '''sums = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, C, D, H, W], filters=C, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(integral_img))'''
            # Window sums, one channel at a time (there is no >3D conv API to
            # do all channels at once, see note 4). Gradients are stopped: the
            # conv only evaluates a fixed +-1 kernel over a cumsum of inputs.
            # NOTE(review): `dim=` is the deprecated TF 1.x alias of `axis=`.
            sums = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(tf.expand_dims(integral_img[:, 0, :, :, :], dim=1)))
            for i in range(1, C):
                temp = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(tf.expand_dims(integral_img[:, i, :, :, :], dim=1)))
                sums = tf.concat([sums, temp], axis=1)
            '''integral_img = tf.reshape(integral_img, [-1, 1, C*D, H, W])
            d = [self.window_size[0], self.window_size[1], self.window_size[2]]
            sums = tf.stop_gradient(
                keras.layers.Conv3D(input_shape=[-1, 1, C*D, H, W], filters=1, kernel_size=2, padding='valid',
                                    kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1],
                                    dilation_rate=d,
                                    data_format='channels_first')(integral_img))
            sums = tf.reshape(sums, [-1, C, sums.shape.as_list()[2] // C, sums.shape.as_list()[3], sums.shape.as_list()[4]])'''
            # Edge length of the valid dilated-conv output: img_size - window_size.
            temp_shape = self.img_size - self.window_size[0]
            # Sum each group of `channels_per_group` channels with a strided
            # all-ones conv -> one statistics map per group.
            sums = tf.expand_dims(tf.reshape(sums, [-1, C, temp_shape, temp_shape*temp_shape]), axis=1)
            sums = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, C, temp_shape, temp_shape*temp_shape],
                                filters=1, kernel_size=[self.channels_per_group, 1, 1], padding='valid',
                                kernel_initializer=keras.initializers.Constant(c_kernel), strides=[self.channels_per_group, 1, 1],
                                data_format='channels_first')(sums))
            assert G == sums.shape.as_list()[2]
            sums = tf.squeeze(tf.reshape(sums, [-1, 1, G, temp_shape, temp_shape, temp_shape]), squeeze_dims=1) # [B, G, ., ., .]
            # integral_img_sq — identical pipeline for the sums of squares.
            '''squares = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, C, D, H, W], filters=C, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(integral_img_sq))'''
            squares = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(tf.expand_dims(integral_img_sq[:, 0, :, :, :], dim=1)))
            for i in range(1, C):
                temp = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, D, H, W], filters=1, kernel_size=2, padding='valid',
                                kernel_initializer=keras.initializers.Constant(kernel), strides=[1, 1, 1], dilation_rate=d,
                                data_format='channels_first')(tf.expand_dims(integral_img_sq[:, i, :, :, :], dim=1)))
                squares = tf.concat([squares, temp], axis=1)
            temp_squares_shape = self.img_size - self.window_size[0]
            squares = tf.expand_dims(tf.reshape(squares, [-1, C, temp_squares_shape, temp_squares_shape*temp_squares_shape]), axis=1)
            squares = tf.stop_gradient(keras.layers.Conv3D(input_shape=[-1, 1, C, temp_squares_shape, temp_squares_shape*temp_squares_shape],
                                filters=1, kernel_size=[self.channels_per_group, 1, 1], padding='valid',
                                kernel_initializer=keras.initializers.Constant(c_kernel), strides=[self.channels_per_group, 1, 1],
                                data_format='channels_first')(squares))
            assert G == squares.shape.as_list()[2]
            squares = tf.squeeze(tf.reshape(squares, [-1, 1, G, temp_squares_shape, temp_squares_shape, temp_squares_shape]), squeeze_dims=1) # [B, G, ., ., .]
            # Number of elements contributing to each window's statistics.
            n = self.window_size[0] * self.window_size[1] * self.window_size[2] * self.channels_per_group
            means = sums / n
            # var = E[x^2] - E[x]^2, written in terms of the raw sums.
            var = 1.0 / n * (squares - sums * sums / n)
            # Reflect-pad the smaller statistics maps back to the full volume
            # so every voxel gets a mean/var (this requires the window to be
            # at most half the image size, see note 6).
            d, h, w = temp_shape, temp_shape, temp_shape
            pad3d = [int(math.floor((D - d) / 2)), int(math.ceil((D - d) / 2)), int(math.floor((H - h) / 2)),
                     int(math.ceil((H - h) / 2)), int(math.floor((W - w) / 2)), int(math.ceil((W - w) / 2))]
            padded_means = tf.pad(means, [[0, 0], [0, 0], [pad3d[0], pad3d[1]], [pad3d[2], pad3d[3]], [pad3d[4], pad3d[5]]], 'REFLECT')
            padded_vars = tf.pad(var, [[0, 0], [0, 0], [pad3d[0], pad3d[1]], [pad3d[2], pad3d[3]], [pad3d[4], pad3d[5]]], 'REFLECT') + self.eps
            # Normalize each channel group with its own local statistics.
            temp = (inputs[:, :self.channels_per_group, :, :, :] -
                    tf.expand_dims(padded_means[:, 0, :, :, :], dim=1)) / tf.sqrt((tf.expand_dims(padded_vars[:, 0, :, :, :], dim=1)))
            for i in range(1, G):
                t_temp = (inputs[:, i * self.channels_per_group:i * self.channels_per_group + self.channels_per_group, :, :, :] -
                          tf.expand_dims(padded_means[:, i, :, :, :], dim=1)) / tf.sqrt((tf.expand_dims(padded_vars[:, i, :, :, :], dim=1)))
                temp = tf.concat([temp, t_temp], axis=1)
            inputs = temp
            return inputs

        #inputs = use_window(inputs) if self.window_size[0] < D else no_use_window(inputs)
        # NOTE: assumes D, H, W are equal and the input's spatial dims all
        # equal img_size. tf.cond cannot build both branches here (note 2),
        # hence the separate GroupContextNorm class for the no-window case.
        inputs = use_window(inputs)
        # Back to channels-last, then apply the learned affine transform.
        inputs = tf.transpose(inputs, [0, 2, 3, 4, 1])
        return inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        # Normalization preserves the input shape.
        return input_shape
class GroupContextNorm(Layer): # 3D
    """Group Normalization companion to LocalContextNorm.

    Used when the LCN window would cover the whole volume: statistics are
    computed over each full channel group instead of a local window
    (tf.cond cannot build both graphs, see note 2 at the bottom of the file,
    hence a separate class). Input/output layout is channels-last:
    [B, D, H, W, C].
    """

    def __init__(self, channels_per_group=1, window_size=(9, 9, 9), eps=1e-5, img_size=128):
        # channels_per_group: number of channels pooled into one statistics group
        # window_size: kept for interface parity with LocalContextNorm; only
        #              its first entry is checked against img_size
        # eps: numerical-stability constant added to the variance
        # img_size: spatial edge length of the (cubic) input volume
        super(GroupContextNorm, self).__init__()
        self.channels_per_group = channels_per_group
        self.eps = eps
        self.window_size = window_size # [D, H, W]
        self.img_size = img_size

    def build(self, input_shape):
        # Expect a 5-D channels-last tensor: [B, D, H, W, C].
        if len(input_shape) != 5:
            raise Exception('Input of LCN layer should have 5 dims with [B, D, H, W, C]!')
        if self.img_size > self.window_size[0]:
            raise Exception('Window size must be large than image size in GN case!')
        self.num_features = input_shape[-1]
        # Learned per-channel affine parameters (gamma / beta), broadcast over B, D, H, W.
        self.weight = tf.Variable(tf.ones([1, 1, 1, 1, self.num_features]), trainable=True)
        self.bias = tf.Variable(tf.zeros([1, 1, 1, 1, self.num_features]), trainable=True)
        self.built = True

    def call(self, inputs, **kwargs):
        # Work internally in channels-first layout: [B, C, D, H, W].
        inputs = tf.transpose(inputs, [0, 4, 1, 2, 3])
        inputs_shape = inputs.shape.as_list()
        # BUG FIX: the original split this unpacking across two statements
        # (trailing comma after inputs_shape[2]), assigning a 3-tuple to five
        # names and raising ValueError at graph-build time.
        _, C, D, H, W = inputs_shape
        G = math.floor(C / self.channels_per_group)  # number of channel groups
        assert C % self.channels_per_group == 0
        img_size = self.img_size

        def no_use_window(inputs):
            # Standard Group Norm: flatten each group and normalize by its
            # global mean/variance.
            inputs_shape = inputs.shape.as_list()
            inputs = tf.reshape(inputs, [-1, G, inputs_shape[1] // G * img_size * img_size * img_size])
            means, var = tf.nn.moments(inputs, [2], keep_dims=True)
            inputs = tf.reshape((inputs - means) / tf.sqrt(var + self.eps), [-1, C, img_size, img_size, img_size])
            return inputs

        # NOTE: assumes D, H, W are equal and the input's spatial dims all
        # equal img_size.
        inputs = no_use_window(inputs)
        # Back to channels-last, then apply the learned affine transform.
        inputs = tf.transpose(inputs, [0, 2, 3, 4, 1])
        return inputs * self.weight + self.bias

    def compute_output_shape(self, input_shape):
        # Normalization preserves the input shape (added for parity with
        # LocalContextNorm).
        return input_shape
if __name__ == '__main__':
    # Smoke test: build the LCN graph on a small synthetic volume.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    # 4-channel 7x7x7 volume whose values increase along every spatial axis
    # (triple cumulative sum of ones).
    demo_volume = tf.cumsum(tf.cumsum(tf.cumsum(tf.ones([2, 7, 7, 7, 4]), axis=1), axis=2), axis=3)
    lcn = LocalContextNorm(channels_per_group=2, window_size=[3, 3, 3],
                           img_size=demo_volume.shape.as_list()[1])
    normalized = lcn(demo_volume)
    # To inspect values, open a session and initialize the variables:
    #   sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    #   tf.global_variables_initializer().run(session=sess)
    #   print(sess.run(normalized))
    #   print(normalized.shape.as_list())
Some Notes
1、keras 报错 ‘_TfDeviceCaptureOp’ object has no attribute ‘type’,tf1.4对应keras2.0.8,重装keras=2.0.8后解决。
2、tf.cond无法完成建图,故上述代码写了两个class。
3、tf.nn.layer以及tflearn.layers中在tf1.4版本下的conv3d没有dilation_rate参数,故使用keras.layer.Conv3D。
4、二维积分图运算在参考代码中使用了三维卷积,故三维积分图需要使用四维卷积,但是没有3维卷积以上API,故采用Conv3D + tf.concat的策略。
5、三维积分图运算矩阵为 kernel = [[[-1, 1], [1, -1]], [[1, -1], [-1, 1]]],可自行推导,二维积分图运算矩阵参考上述参考代码。
6、tf.pad无法完成比input size还大的padding,故use_window的window_size最大为img_size一半。



