栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Metric learning

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Metric learning

一:减小类内距离,增大类间距离 (Approach 1: shrink intra-class distance, enlarge inter-class distance)

class Metric_loss(nn.Module):
    """Centroid-based metric-learning loss.

    Computes two terms over a batch of source-domain features:
    an intra-class term (distance of each feature to its own class
    centroid, to be minimized) and an inter-class term (spread of the
    class centroids around the global centroid, to be maximized by
    the caller).
    """

    def __init__(self, src_class):
        """
        Args:
            src_class: number of classes in the source domain.
        """
        super(Metric_loss, self).__init__()
        self.n_class = src_class

    def forward(self, s_feature, s_labels):
        """Return ``(inter_loss, intra_loss)`` for one batch.

        Args:
            s_feature: ``(n, d)`` float tensor of features.
            s_labels: ``(n,)`` int64 tensor of class ids in ``[0, n_class)``.

        Returns:
            inter_loss: summed per-dimension L1 spread of class centroids
                around the global centroid, averaged over ``d``.
            intra_loss: summed per-dimension L1 distance of each feature
                to its own class centroid, averaged over ``d``.
        """
        n, d = s_feature.shape
        # Follow the input's device instead of hard-coding .cuda(),
        # so the loss also works on CPU-only machines.
        device = s_feature.device

        # Per-class sample counts; clamp to >= 1 so classes absent from
        # the batch do not cause a division by zero below.
        ones = torch.ones_like(s_labels, dtype=torch.float)
        s_n_classes = torch.zeros(self.n_class, device=device).scatter_add(0, s_labels, ones)
        s_n_classes = s_n_classes.clamp(min=1)

        # Class centroids: per-class feature sums divided by counts.
        zeros = torch.zeros(self.n_class, d, device=device)
        index = s_labels.unsqueeze(1).expand(n, d)
        s_sum_feature = zeros.scatter_add(0, index, s_feature)
        s_centroid = torch.div(s_sum_feature, s_n_classes.view(self.n_class, 1))

        # Centroid of each sample's own class — vectorized gather,
        # replacing the original Python-level loop over the batch.
        own_centroid = s_centroid[s_labels]

        # Global centroid; broadcasting replaces the original repeat().
        s_all_centroid = s_centroid.sum(dim=0) / self.n_class

        # L1 distances summed across the batch/class axis, averaged over d.
        inter_loss = torch.norm(s_all_centroid - s_centroid, p=1, dim=0).sum() / d
        intra_loss = torch.norm(own_centroid - s_feature, p=1, dim=0).sum() / d

        return inter_loss, intra_loss

二:三元组损失 (Approach 2: triplet loss)

# Usage example: mine batch-hard triplets from a feature batch, then
# compute the triplet loss on the mined (anchor, positive, negative).
# NOTE(review): `feature` and `src_label` are assumed to come from the
# surrounding training loop — they are not defined in this snippet.
selector = BatchHardTripletSelector()
anchor, pos, neg = selector(feature, src_label)
triplet_loss = TripletLoss(margin=1).cuda()
triplet = triplet_loss(anchor, pos, neg)
class TripletLoss(nn.Module):
    '''
    Triplet loss over pre-mined (anchor, positive, negative) triplets.

    With a numeric ``margin`` this wraps ``nn.TripletMarginLoss``;
    with ``margin=None`` it uses the soft-margin formulation
    ``log(1 + exp(d(a, p) - d(a, n)))`` via ``nn.SoftMarginLoss``.
    '''
    def __init__(self, margin = None):
        super(TripletLoss, self).__init__()
        self.margin = margin
        # No explicit margin -> fall back to the soft-margin loss.
        self.Loss = (nn.SoftMarginLoss() if margin is None
                     else nn.TripletMarginLoss(margin = margin, p = 2))

    def forward(self, anchor, pos, neg):
        if self.margin is not None:
            # Hard-margin path: delegate directly to TripletMarginLoss.
            return self.Loss(anchor, pos, neg)
        # Soft-margin path: feed the distance gap d(a,n) - d(a,p) with
        # target +1, pushing negatives farther away than positives.
        batch = anchor.shape[0]
        target = t.ones((batch, 1)).view(-1)
        if anchor.is_cuda:
            target = target.cuda()
        an_gap = t.norm(anchor - neg, 2, dim = 1).view(-1)
        ap_gap = t.norm(anchor - pos, 2, dim = 1).view(-1)
        return self.Loss(an_gap - ap_gap, target)

class BatchHardTripletSelector(object):
    '''
    Batch-hard triplet mining: for every anchor in the batch, pick the
    farthest same-label embedding as the positive and the nearest
    different-label embedding as the negative.
    '''
    def __init__(self, *args, **kwargs):
        super(BatchHardTripletSelector, self).__init__()

    def __call__(self, embeds, labels):
        # Pairwise distance matrix, moved to numpy for the mining step.
        dists = pdist_torch(embeds, embeds).detach().cpu().numpy()
        lbls = labels.contiguous().cpu().numpy().reshape((-1, 1))
        num = lbls.shape[0]

        # Boolean mask of same-label pairs, self-pairs excluded.
        same = lbls == lbls.T
        np.fill_diagonal(same, False)

        # Hardest positive: largest distance among same-label pairs
        # (different-label entries masked to -inf so argmax ignores them).
        pos_idxs = np.argmax(np.where(same, dists, -np.inf), axis = 1)

        # Hardest negative: smallest distance among different-label pairs
        # (same-label entries and the diagonal masked to +inf).
        np.fill_diagonal(same, True)
        neg_idxs = np.argmin(np.where(same, np.inf, dists), axis = 1)

        pos = embeds[pos_idxs].contiguous().view(num, -1)
        neg = embeds[neg_idxs].contiguous().view(num, -1)
        return embeds, pos, neg

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/350428.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号