栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

DJL-Java开发者动手学深度学习之归一化处理及源代码

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

DJL-Java开发者动手学深度学习之归一化处理及源代码

在深度学习训练中,通过会对数据进行归一化处理。通常讲,归一化有两点好处:
1、使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
2、加快学习算法的收敛速度。

MinMax归一化

将数据缩放到0和1之间,公式如下:
Y = X i − m i n ( X i ) m a x ( X i ) − m i n ( x i ) Y = frac{X_i - min(X_i)}{max(X_i) - min(x_i)} Y=max(Xi​)−min(xi​)Xi​−min(Xi​)​

标准归一化

将数据所防伪均值是0,方差为1的状态,公式如下:
Y = X i − μ δ Y = frac{X_i - mu}{delta} Y=δXi​−μ​
其缩放结果为:

均值归一化

将数据缩放到-1和1之间,公式如下:
Y = X i − m e a n ( X ) m a x ( X ) − m i n ( X ) Y = frac{X_i - mean(X)}{max(X) - min(X)} Y=max(X)−min(X)Xi​−mean(X)​

MinMax源代码解析

归一化的公式相对比较简单,只要照着公式去实现即可。
如简单的Python代码实现如下:

# 计算train数据集的最大值,最小值,平均值
maximums, minimums  = training_data.max(axis=0), training_data.min(axis=0)

# 对数据进行归一化处理

for i in range(feature_num):
data[:, i] = (data[:, i] - minimums[i]) / (maximums[i] - minimums[i])

Java的实现相比Python,代码多一点,不过实现内容一样,下面为DJL框架里封装的源代码,详细如下:

public class MinMaxScaler implements AutoCloseable {

    private NDArray fittedMin;
    private NDArray fittedMax;
    private NDArray fittedRange;
    private float minRange;
    private float maxRange = 1f;
    private boolean detached;

    public MinMaxScaler fit(NDArray data, int[] axises) {
        fittedMin = data.min(axises);
        fittedMax = data.max(axises);
        fittedRange = fittedMax.sub(fittedMin);
        if (detached) {
            detach();
        }
        return this;
    }

    public MinMaxScaler fit(NDArray data) {
        fit(data, new int[] {0});
        return this;
    }

    public NDArray transform(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.sub(fittedMin).divi(fittedRange);
        return scale(std);
    }

    public NDArray transformi(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.subi(fittedMin).divi(fittedRange);
        return scale(std);
    }

    private NDArray scale(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.muli(maxRange - minRange).addi(minRange);
        }
        return std;
    }

    private NDArray inverseScale(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.sub(minRange).divi(maxRange - minRange);
        }
        return std.duplicate();
    }

    private NDArray inverseScalei(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.subi(minRange).divi(maxRange - minRange);
        }
        return std;
    }

    public NDArray inverseTransform(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScale(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    public NDArray inverseTransformi(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScalei(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    private void throwsIllegalStateWhenNotFitted() {
        if (fittedRange == null) {
            throw new IllegalStateException("Min Max Scaler is not fitted");
        }
    }

    public MinMaxScaler detach() {
        detached = true;
        if (fittedMin != null) {
            fittedMin.detach();
        }
        if (fittedMax != null) {
            fittedMax.detach();
        }
        if (fittedRange != null) {
            fittedRange.detach();
        }
        return this;
    }

    public MinMaxScaler optRange(float minRange, float maxRange) {
        this.minRange = minRange;
        this.maxRange = maxRange;
        return this;
    }

    public NDArray getMin() {
        throwsIllegalStateWhenNotFitted();
        return fittedMin;
    }

    public NDArray getMax() {
        throwsIllegalStateWhenNotFitted();
        return fittedMax;
    }

    @Override
    public void close() {
        if (fittedMin != null) {
            fittedMin.close();
        }
        if (fittedMax != null) {
            fittedMax.close();
        }
        if (fittedRange != null) {
            fittedRange.close();
        }
    }
}

关注公众号,解锁更多深度学习内容。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/737296.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号