栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

从实战角度来看类的继承小细节

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

从实战角度来看类的继承小细节

背景

前一段时间,经常把transformers包的源码拿出来看看,给我留下影响非常深刻的就是transformers包的设计模式。非常的优秀。

而实现这个设计模式,自然离不开python的class模块。类的继承起到了非常大的作用。

那么就类的继承这一部分,对类的继承做一做分享。

问题归类 问题1 两种super写法有什么差异么?
class Animal:
    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __str__(self):
        return f"from Animal --> name: {self.name}, age: {self.age}"


class Dog(Animal):
    def __init__(self, name, age, dog_type):
        # 看这里
        # super(Dog, self).__init__(name, age) # way 1
        super().__init__(name, age)  # way 2
        
        self.dog_type = dog_type

    def __str__(self):
        return f"from Dog --> name: {self.name}, age: {self.age}, type: {self.dog_type}"


if __name__ == '__main__':
    dog = Dog(name='peter', age=12, dog_type='土狗')
    print(dog)

上面的代码,表示了Dog这个类是继承自Animal的。

  1. 先说答案:使用way 1还是使用way 2都是没有区别的,不用纠结。两个是一样的。
  2. 在pycharm编辑器中,推荐使用way 1。因为他直接提醒的就是这个。
  3. way 2写起来比较简单,省事。
问题2 super的作用机制是什么?

通常来说:子类继承父类这个规则大家都是知道的。很多时候就是搞不懂多重继承的时候,链路是什么样子的。

这里直接分享python继承的基本规则和原理:

原理:

对于你定义的每一个类,Python 会计算出一个方法解析顺序(Method Resolution Order, MRO)列表,它代表了类继承的顺序,我们可以使用下面的方式获得某个类的 MRO 列表。

MRO列表确定是通过一个C3线性化算法来实现。具体算法这里不做解释。

规则
  1. 子类永远在父类前面。
  2. 如果有多个父类,会根据它们在列表中的顺序被检查。
  3. 如果对下一个类存在两个合法的选择,选择第一个父类。

下面这部分代码,可以帮助你了解的更加深刻

class A:
    def __init__(self):
        print("into A")
        print("leave A")


class B(A):
    def __init__(self):
        print("into B")
        super(B, self).__init__()
        print("leave B")


class C(A):
    def __init__(self):
        print("into C")
        super(C, self).__init__()
        print("leave C")


class E(C, B):
    def __init__(self):
        print("into E")
        super(E, self).__init__()
        print("leave E")


if __name__ == '__main__':
    print(E.__mro__) # 这个就是mro链
    e = E()

运行结果:

这个反映出类E的继承顺序。

(, , , , )

这里是反映整个程序的继承细节:(其实如果开启调试,然后逐步运行,也可以看到整个的程序处理流程)

into E
into C
into B
into A
leave A
leave B
leave C
leave E
问题2 父类的方法定义?

有时候,为了保证接口的统一性,会在父类定义好方法,但是却不实现它。就像是这样的:

可以看出来:PreTrainedModel的两个方法_init_weights都是没有实现的,只是定义了。

# 代码部分省略了
# 代码来源:transformers-master/src/transformers/modeling_utils.py

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
    r"""
    Base class for all models.

    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:
    """
    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"

    def _init_weights(self, module):
        """
        Initialize the weights. This method should be overridden by derived class.
        """
        raise NotImplementedError(f"Make sure `_init_weigths` is implemented for {self.__class__}")


但是在子类里面,都把_init_weights方法实现了。

# 代码部分省略了
# 代码来源: transformers-master/src/transformers/models/bert/modeling_bert.py

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

这种思想非常优秀,以至于影响我在写rust的时候,在设计代码结构的时候,要多多考虑考虑接口的统一,一个方法可以适用于不同的模型。

参考链接
  1. 设计模式:https://zhuanlan.zhihu.com/p/31700225
  2. super的作用介绍:https://blog.csdn.net/wo198711203217/article/details/84097274
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/840556.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号