前一段时间,经常把transformers包的源码拿出来看看,给我留下影响非常深刻的就是transformers包的设计模式。非常的优秀。
而实现这个设计模式,自然离不开python的class模块。类的继承起到了非常大的作用。
那么就类的继承这一部分,对类的继承做一做分享。
问题归类 问题1 两种super写法有什么差异么?
class Animal:
def __init__(self, name, age):
self.name = name
self.age = age
def __str__(self):
return f"from Animal --> name: {self.name}, age: {self.age}"
class Dog(Animal):
def __init__(self, name, age, dog_type):
# 看这里
# super(Dog, self).__init__(name, age) # way 1
super().__init__(name, age) # way 2
self.dog_type = dog_type
def __str__(self):
return f"from Dog --> name: {self.name}, age: {self.age}, type: {self.dog_type}"
if __name__ == '__main__':
dog = Dog(name='peter', age=12, dog_type='土狗')
print(dog)
上面的代码,表示了Dog这个类是继承自Animal的。
- 先说答案:使用way 1还是使用way 2都是没有区别的,不用纠结。两个是一样的。
- 在pycharm编辑器中,推荐使用way 1。因为他直接提醒的就是这个。
- way 2写起来比较简单,省事。
通常来说:子类继承父类这个规则大家都是知道的。很多时候就是搞不懂多重继承的时候,链路是什么样子的。
这里直接分享python继承的基本规则和原理:
原理:对于你定义的每一个类,Python 会计算出一个方法解析顺序(Method Resolution Order, MRO)列表,它代表了类继承的顺序,我们可以使用下面的方式获得某个类的 MRO 列表。
MRO列表确定是通过一个C3线性化算法来实现。具体算法这里不做解释。
规则- 子类永远在父类前面。
- 如果有多个父类,会根据它们在列表中的顺序被检查。
- 如果对下一个类存在两个合法的选择,选择第一个父类。
下面这部分代码,可以帮助你了解的更加深刻
class A:
def __init__(self):
print("into A")
print("leave A")
class B(A):
def __init__(self):
print("into B")
super(B, self).__init__()
print("leave B")
class C(A):
def __init__(self):
print("into C")
super(C, self).__init__()
print("leave C")
class E(C, B):
def __init__(self):
print("into E")
super(E, self).__init__()
print("leave E")
if __name__ == '__main__':
print(E.__mro__) # 这个就是mro链
e = E()
运行结果:
这个反映出类E的继承顺序。
(, , , , )
这里是反映整个程序的继承细节:(其实如果开启调试,然后逐步运行,也可以看到整个的程序处理流程)
into E into C into B into A leave A leave B leave C leave E问题2 父类的方法定义?
有时候,为了保证接口的统一性,会在父类定义好方法,但是却不实现它。就像是这样的:
可以看出来:PreTrainedModel的两个方法_init_weights都是没有实现的,只是定义了。
# 代码部分省略了
# 代码来源:transformers-master/src/transformers/modeling_utils.py
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods
for loading, downloading and saving models as well as a few methods common to all models to:
"""
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
def _init_weights(self, module):
"""
Initialize the weights. This method should be overridden by derived class.
"""
raise NotImplementedError(f"Make sure `_init_weigths` is implemented for {self.__class__}")
但是在子类里面,都把_init_weights方法实现了。
# 代码部分省略了
# 代码来源: transformers-master/src/transformers/models/bert/modeling_bert.py
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
这种思想非常优秀,以至于影响我在写rust的时候,在设计代码结构的时候,要多多考虑考虑接口的统一,一个方法可以适用于不同的模型。
参考链接- 设计模式:https://zhuanlan.zhihu.com/p/31700225
- super的作用介绍:https://blog.csdn.net/wo198711203217/article/details/84097274



