栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

AE(associative embedding) loss TF版本实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

AE(associative embedding) loss TF版本实现

        主要两个修改点

        a.第一版输入tag0/tag1来自两个不同feature map 我这边用的是同一个feature map 并且支持任意关键点数量。

        b.第一版按CornerNet的实现 当出现单个目标时存在bug(coco数据中大多是多目标)。理论上push值应为0 但实际不是 主要修复该bug。

        

# y_pred: (b, h, w, 1)
# mask (b, max_objects) , [[1,0,0...],[1,1,1,0,0...]....]
# indices (b, max_objects*4) , [[312,555,666,777,0,...],[........]....]
# output : (b, max_objects, 4) [[1.0, 2.0, 3.0, 4.0],.....]
def trans_objects(y_pred, mask, indices):
 #batch, channel
 b, c tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
 #max_objects
 max_corners tf.shape(indices)[1]
 max_objs max_corners / 4
 y_pred tf.reshape(y_pred, (b, -1, c)) #b, h*w, c 1
 length tf.shape(y_pred)[1] #h*w
 indices tf.cast(indices, tf.int32) #b, max_objs*4
 #y_pred b, h*w, c - b, max_objs, 2 只计算indices指定坐标的loss
 batch_idx tf.expand_dims(tf.range(0, b), 1) #b, 1
 batch_idx tf.tile(batch_idx, (1, max_corners)) #b, max_corners. [[0,0,0..],[1,1,1...],...]
 full_indices (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) tf.reshape(indices, [-1])) #b*max_objs*4. [0 312,0,....,h*w 65,h*w 203,h*w 1105,h*w 0....,2*h*w ?,2*h*w 0,.......]
 y_pred tf.gather(tf.reshape(y_pred, [-1, c]), full_indices)#根据full_indices 在b*h*w维度中筛选,
 y_pred tf.reshape(y_pred, [b, -1, 4]) # [b*max_objs*4, c 1] - [b, max_objs, 4]
 #mask
 mask tf.tile(tf.expand_dims(mask, axis -1), (1, 1, 4)) # b,max_objs - b,max_objs,1 - b,max_objs,4。 4个角点
 return y_pred * mask
# y_pred: embedding, (batch_size, out_h, out_w, 1) 
# mask (batch_size, max_objects) , [[1,0,0...],[1,1,1,0,0...]....]
# indices 4个角点索引 (batch_size, max_objects*4) , [[312,555,666,777,0,...],[........]....]
def my_ae_loss(y_pred, mask, indices):
 # b,h,w,1 - b, n, 4
 tag trans_objects(y_pred, mask, indices)
 max_objs tf.shape(mask)[1] # n
 num tf.reduce_sum(mask, axis 1, keepdims True) # b,n - b,1 . sum(n) 每个batch的num可能不一样
 tag_mean tf.reduce_mean(tag, axis -1, keepdims True) # b,n,4 - b,n,1 .同一object内的embedding均值
 pull tf.pow(tag - tag_mean, 2) # 每个点减去均值。 b,n,4。 tag_mean b,n,1(自动扩4)
 pull pull / tf.expand_dims(tf.tile(num 1e-4, (1, max_objs)), axis -1)
 pull pull * tf.expand_dims(mask, axis -1) # mask内统计。 b,n,4。 mask b,n - b,n,1 - b,n,4
 pull tf.reduce_sum(pull)
 #print(pull.shape)
 #push计算(类外差)
 mask mask * tf.cast(tf.greater(num, 1), tf.float32) # by lvjj, num 1的mask 直接置0 其pull值会为0(与num 0一样)
 mask tf.expand_dims(mask, 1) tf.expand_dims(mask, 2)
 mask tf.equal(mask, 2) # b,n,n
 num tf.expand_dims(num, 2) # b, 1 - b, 1, 1
 num2 (num - 1) * num
 tag_mean tf.squeeze(tag_mean,axis -1) # b,n,1 - b,n
 dist tf.expand_dims(tag_mean, 1) - tf.expand_dims(tag_mean, 2) # b,n,n。 混淆矩阵 (b,i,j)表示batch中i类j类的差
 dist 1 - tf.abs(dist) # 这里对角线会出现1 不应该包含在后续sum计算中
 dist tf.nn.relu(dist)
 dist dist - 1 / (num 1e-4) # 抵消对角线1的影响 扣掉后才会与paper公式一致
 dist dist / (num2 1e-4)
 mask tf.cast(mask, tf.float32)
 push tf.reduce_sum(dist*mask)
 return pull, push

        测试代码 mask[1,1] 表示有两个目标 mask[1,0]表示一个目标 后者在原实现中会有问题

# 2,2,2,1 . embedding数值
y_pred tf.constant([ [[[1.],[2.]], 
 [[1.],[2.]],],
 [[[2.],[1.]], 
 [[2.],[1.]],],
 ], dtype tf.float32)
# 2, max_objs 2 . 
mask tf.constant([ [1,1],
 [1,1],
 ], dtype tf.float32)
# 2, max_objs*4 8
indices tf.constant([ [0,0,2,2, 1,1,3,3],
 [0,0,2,2, 1,1,3,3],
 ], dtype tf.float32)
my_ae_loss(y_pred, mask, indices)
 ( tf.Tensor: shape (), dtype float32, numpy 0.0 ,
 tf.Tensor: shape (), dtype float32, numpy 9.9897385e-05 )

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/268120.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号