
211112-Demo: Computing the MMD Similarity of Multidimensional Data Distributions

  • Results

  • Main script

import matplotlib.pyplot as plt
import numpy as np

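# Two 2-D Gaussians; multivariate_normal expects symmetric positive-semidefinite covariances
mean1 = [0, 0]
cov1 = [[5, -5], [-5, 10]]  # negatively correlated (not diagonal)

mean2 = [10, 0]
cov2 = [[5, 5], [5, 10]]  # positively correlated (not diagonal)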

x1, y1 = np.random.multivariate_normal(mean1, cov1, 5000).T
x2, y2 = np.random.multivariate_normal(mean2, cov2, 5000).T
plt.scatter(x1, y1, c='r', alpha=0.1)
plt.scatter(x2, y2, c='b', alpha=0.1)
plt.axis('equal')
plt.show()

source = np.vstack([x1, y1]).T
target = np.vstack([x2, y2]).T


# MMD helpers from the module listed below, saved as utils/mmd_numpy_sklearn.py
from utils.mmd_numpy_sklearn import mmd_linear, mmd_poly, mmd_rbf

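def cal_dis(source, target, function):
    """Print the MMD estimate for four pairings of the two sample sets."""
    d1 = function(source, target)  # source vs. target
    d2 = function(target, source)  # equals d1 up to rounding: the estimate is symmetric
    d3 = function(source, source)  # about 0: a sample set compared with itself
    d4 = function(target, target)  # about 0
    print('D1-{:.2f}; D2-{:.2f}; D3-{:.2f}; D4-{:.2f}.'.format(d1, d2, d3, d4))

# mmd_poly and mmd_rbf build 5000 x 5000 kernel matrices (about 200 MB each in float64)
cal_dis(source, target, mmd_linear)
cal_dis(source, target, mmd_poly)
cal_dis(source, target, mmd_rbf)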
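`np.random.multivariate_normal` expects a symmetric positive-semidefinite covariance matrix; by default (`check_valid='warn'`) it only emits a `RuntimeWarning` when that is violated and still returns samples. A minimal sketch of validating the matrices up front and drawing reproducible samples with NumPy's `Generator` API. The helper `check_cov`, the seed, and the symmetric covariance values are illustrative assumptions, not part of the original demo.

```python
import numpy as np


def check_cov(cov, tol=1e-8):
    """Hypothetical helper: raise ValueError unless cov is symmetric PSD."""
    cov = np.asarray(cov, dtype=float)
    if not np.allclose(cov, cov.T, atol=tol):
        raise ValueError("covariance matrix is not symmetric")
    # eigvalsh assumes symmetry, so it runs only after the check above
    if np.linalg.eigvalsh(cov).min() < -tol:
        raise ValueError("covariance matrix is not positive-semidefinite")
    return cov


rng = np.random.default_rng(0)  # seeded: repeated runs draw the same points

cov1 = check_cov([[5, -5], [-5, 10]])  # assumed symmetric covariances
cov2 = check_cov([[5, 5], [5, 10]])
# check_valid="raise" turns NumPy's own PSD warning into an error
source = rng.multivariate_normal([0, 0], cov1, size=5000, check_valid="raise")
target = rng.multivariate_normal([10, 0], cov2, size=5000, check_valid="raise")
print(source.shape, target.shape)  # (5000, 2) (5000, 2)

rejected = False
try:
    check_cov([[5, 0], [-5, 10]])  # a non-symmetric matrix
except ValueError as err:
    rejected = True
    print("rejected:", err)
```

Raising early is a design choice: a warning is easy to miss in a long script, while the samples drawn from an invalid matrix silently differ from what the plot is meant to show.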
  • MMD functions (utils/mmd_numpy_sklearn.py)
# Compute MMD (maximum mean discrepancy) using numpy and scikit-learn.

import numpy as np
from sklearn import metrics


def mmd_linear(X, Y):
    """MMD using linear kernel (i.e., k(x,y) = <x,y>)
    Note that this is not the original linear MMD, only the reformulated and faster version.
    The original version is:
        def mmd_linear(X, Y):
            XX = np.dot(X, X.T)
            YY = np.dot(Y, Y.T)
            XY = np.dot(X, Y.T)
            return XX.mean() + YY.mean() - 2 * XY.mean()

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Returns:
        [scalar] -- [MMD value]
    """
    delta = X.mean(0) - Y.mean(0)  # difference of the empirical means
    return delta.dot(delta)  # squared Euclidean norm ||mean(X) - mean(Y)||^2


def mmd_rbf(X, Y, gamma=1.0):
    """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2))

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        gamma {float} -- [kernel parameter] (default: {1.0})

    Returns:
        [scalar] -- [MMD value]
    """
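    # Biased (V-statistic) estimate: every kernel-matrix entry, diagonal included, is averaged
    XX = metrics.pairwise.rbf_kernel(X, X, gamma=gamma)
    YY = metrics.pairwise.rbf_kernel(Y, Y, gamma=gamma)
    XY = metrics.pairwise.rbf_kernel(X, Y, gamma=gamma)
    return XX.mean() + YY.mean() - 2 * XY.mean()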


def mmd_poly(X, Y, degree=2, gamma=1, coef0=0):
    """MMD using polynomial kernel (i.e., k(x,y) = (gamma * <x,y> + coef0)^degree)

    Arguments:
        X {[n_sample1, dim]} -- [X matrix]
        Y {[n_sample2, dim]} -- [Y matrix]

    Keyword Arguments:
        degree {int} -- [degree] (default: {2})
        gamma {float} -- [scale applied to the inner product] (default: {1})
        coef0 {float} -- [constant term] (default: {0})

    Returns:
        [scalar] -- [MMD value]
    """
    XX = metrics.pairwise.polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=coef0)
    YY = metrics.pairwise.polynomial_kernel(Y, Y, degree=degree, gamma=gamma, coef0=coef0)
    XY = metrics.pairwise.polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
    return XX.mean() + YY.mean() - 2 * XY.mean()


if __name__ == '__main__':
    a = np.arange(1, 10).reshape(3, 3)
    b = [[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]]
    b = np.array(b)
    print(a)
    print(b)
    print(mmd_linear(a, b))  # 6.0
    print(mmd_rbf(a, b))  # 0.5822
    print(mmd_poly(a, b))  # 2436.5
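The module relies on scikit-learn only to build the kernel matrices. A NumPy-only sketch of the same biased estimators (mean of each kernel matrix, diagonal included), assuming scikit-learn's kernel conventions: linear <x, y>, RBF exp(-gamma ||x - y||^2), polynomial (gamma <x, y> + coef0)^degree. The `_np` function names and the helpers `sq_dists` and `mmd_from_kernel` are hypothetical. The sketch also checks that the original XX/YY/XY form of the linear MMD quoted in the `mmd_linear` docstring agrees with the faster mean-difference form, and that the values in the `__main__` comments are reproduced.

```python
import numpy as np


def sq_dists(X, Y):
    """Pairwise squared Euclidean distances, shape (len(X), len(Y))."""
    d = (X ** 2).sum(1)[:, None] + (Y ** 2).sum(1)[None, :] - 2 * X @ Y.T
    return np.maximum(d, 0)  # clip tiny negatives caused by rounding


def mmd_from_kernel(kernel, X, Y):
    """Biased MMD^2: mean k(X,X) + mean k(Y,Y) - 2 * mean k(X,Y)."""
    return kernel(X, X).mean() + kernel(Y, Y).mean() - 2 * kernel(X, Y).mean()


def mmd_linear_np(X, Y):
    # original (slow) form: full linear kernel matrices
    return mmd_from_kernel(lambda A, B: A @ B.T, X, Y)


def mmd_rbf_np(X, Y, gamma=1.0):
    return mmd_from_kernel(lambda A, B: np.exp(-gamma * sq_dists(A, B)), X, Y)


def mmd_poly_np(X, Y, degree=2, gamma=1, coef0=0):
    return mmd_from_kernel(lambda A, B: (gamma * A @ B.T + coef0) ** degree, X, Y)


a = np.arange(1, 10, dtype=float).reshape(3, 3)
b = np.array([[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]], dtype=float)

delta = a.mean(0) - b.mean(0)  # fast form: squared distance between means
print(mmd_linear_np(a, b), delta @ delta)  # 6.0 6.0
print(round(mmd_rbf_np(a, b), 4))  # 0.5822
print(mmd_poly_np(a, b))  # 2436.5
```

The two linear forms agree because the mean of the Gram matrix X X^T equals the squared norm of the mean of X, so the whole expression collapses to ||mean(X) - mean(Y)||^2 without building any n x n matrix.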
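Because the linear-kernel MMD reduces to the squared distance between the two sample means, it cannot separate distributions that share a mean but differ in spread, whereas the RBF kernel can. A small self-contained sketch illustrating this; `linear_mmd2` and `rbf_mmd2` are hypothetical names that re-derive the two estimators in NumPy, and the seed, sample size and scales are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 2000
X = rng.normal(0.0, 1.0, size=(n, 2))  # N(0, I)
Y = rng.normal(0.0, 3.0, size=(n, 2))  # N(0, 9I): same mean, wider spread


def linear_mmd2(X, Y):
    """Linear-kernel MMD^2 = squared distance between the sample means."""
    delta = X.mean(0) - Y.mean(0)
    return delta @ delta


def rbf_mmd2(X, Y, gamma=1.0):
    """Biased RBF-kernel MMD^2 with k(x, y) = exp(-gamma * ||x - y||^2)."""
    def k(A, B):
        d = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2 * A @ B.T
        return np.exp(-gamma * np.maximum(d, 0))
    return k(X, X).mean() + k(Y, Y).mean() - 2 * k(X, Y).mean()


lin = linear_mmd2(X, Y)
rbf = rbf_mmd2(X, Y)
print(f"linear: {lin:.4f}  rbf: {rbf:.4f}")
```

With these settings the linear value stays close to 0 (only sampling noise in the means), while the RBF value is clearly positive; for these two Gaussians its population value works out to roughly 0.13.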

  • References
  1. https://numpy.org/doc/stable/reference/random/generated/numpy.random.multivariate_normal.html
  2. https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_numpy_sklearn.py