栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Tensorflow中的subclass

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Tensorflow中的subclass

Loss

以交叉熵为例:

tf的loss api一般通过如 tf.losses.SparseCategoricalCrossentropy(from_logits=True) 创建instance,或者不需要pass in参数的情况下:tf.losses.sparse_categorical_crossentropy, 然后通过loss(y_true, y_pred)来call这个函数

Subclass损失函数需要重写 __init__()和 call() 两个函数

  • __init__(self): accept parameters to pass during the call of your loss function
  • call(self, y_true, y_pred): use the targets (y_true) and the model predictions (y_pred) to compute the model's loss
class CustomMSE(keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor
Metrics

metric需要implement以下四个函数

  • __init__(self), in which you will create state variables for your metric.
  • update_state(self, y_true, y_pred, sample_weight=None), which uses the targets y_true and the model predictions y_pred to update the state variables.
  • result(self), which uses the state variables to compute the final results.
  • reset_states(self), which reinitializes the state of the metric.
add_loss and add_metric

you can call self.add_loss(loss_value) from inside the call method of a custom layer. Losses added in this way get added to the "main" loss during training (the one passed to compile())

class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(tf.reduce_sum(inputs) * 0.1)
        return inputs  # Pass-through layer.
class MetricLoggingLayer(layers.Layer):
    def call(self, inputs):
        # The `aggregation` argument defines
        # how to aggregate the per-batch values
        # over each epoch:
        # in this case we simply average them.
        self.add_metric(
            keras.backend.std(inputs), name="std_of_activation", aggregation="mean"
        )
        return inputs  # Pass-through layer.

In the Functional API, you can also call model.add_loss(loss_tensor), or model.add_metric(metric_tensor, name, aggregation).

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/286483.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号