dataclass初探


初嘗

Python 3.7 引入了一個新的模塊,這個模塊就是今天要試探的 dataclass
dataclass 的用法和普通的類裝飾器沒有任何區別,它的作用是替換定義類的時候的:
def __init__()
我們來看看如何使用它

# 我們需要引入 dataclass 包
from dataclasses import dataclass 


@dataclass
class A:
    a: int
    b: int
    c: str
    d: str = "test"

a = A(1, 2, "3")
print(a)

 

我們執行這段代碼,得到結果
A(a=1, b=2, c='3', d='test')
可以看到,它的效果和

class A:
    def __init__(self, a, b, c, d="test"):
        self.a = a
        self.b = b
        self.c = c
        self.d = d
a = A(1, 2, "3")
print(a)

完全一樣!使用了 dataclass 可以省下很多代碼,可以幫我們節約很多時間,代碼也變得很簡潔了。

定義類型

我們發現,使用 dataclass 的時候,需要對初始化的參數進行類型定義,比如上面的例子里面,我為 abcd 定義的類型分別是 intintstr 和 str
那我建立實例的時候,傳遞非定義的類型的數據進去,會報錯么?
答案是很明顯的,是不會報錯的,畢竟 python 是解釋性語言嘛。
當然我們也要試試的

a = A("name", "age", 123, 123)
print(a)

得到結果
A(a='name', b='age', c=123, d=123)
果然是不會報錯的。
但是在 pycharm 之類的 IDE 里面,是會提醒修改的,這點很不爽


那么我們可以使用萬能的類型的么?當然是可以的,但是不建議(畢竟現在都建議寫 python 的工程師加上類型檢查了)
做法如下:

@dataclass
class A:
    a: ""
    b: 1

這樣就可以隨意傳參了。
我們只需要隨意給一個字符串就可以了,也可以事任何的其他類型

繼承

使用了 dataclass 之后,類的繼承還是之前的那樣么?
我們來試試

@dataclass
class A:
    a: int
    b: str


@dataclass
class B(A):
    c: int
    d: int

b = B(a=1, b="2", c=3, d=4)

就完了。
再來想想我們之前的繼承 __init__ 是怎么寫的

class A:
    def __init__(self, a: int, b: str):
        self.a = a
        self.b = b


class B(A):
    def __init__(self, a: int, b: str, c: int, d: int):
        super().__init__(a, b)
        self.c = c
        self.d = d

b = B(a=1, b="2", c=3, d=4)

一對比,是不是上面的代碼簡潔太多太多了!簡直的優化利器!

使用 make_dataclass 快速創建類

除此之外,dataclasses 還提供了一個方法 make_dataclass 讓我們可以快速創建類

from dataclasses import make_dataclass

A = make_dataclass(
    "A", 
    [("a", int), "b", ("c", str), ("d", int, 1)],
    namespace={'add_one': lambda self: self.a + 1})

這個和

@dataclass
class A:
    a: int
    b: ""
    c: str
    d: int = 1

    def add_one(self):
        self.a += 1

是完全一樣的

field

field 在 dataclasses 里面是比較重要的功能, 用於初處理定義的參數非常有用
在 PEP 557 中是這樣描述 field 的

Field objects describe each defined field. These objects are created internally, and are returned by the fields() module-level method (see below). Users should never instantiate a Field object directly.

大致意思就是 Field 對象是用於描述定義的字段的,這些對象是內部定義好了的。然后由 field() 方法返回,用戶不用直接實例化 Field。
我們先看看 field 是如何使用的

from dataclasses import dataclass, field


@dataclass
class A:
    a: str = field(default="123")

可以用於設立默認值,和 a: str = "123" 一個效果,那為什么我們還需要 field 呢?
因為 field 的功能遠不止這一個設置默認值,他還有很多有用的功能

  • 設置是否加載到 __init__ 里面去
@dataclass
class A:
    a: int
    b: int = field(default=10, init=False)
a = A(1) # 注意,實例化 A 的時候只需要一個參數,賦給 a 的

等價於:

class A:
    b = 10
    def __init__(self, a: int):
        self.a = a
  • 設置是否成為 __repr__ 返回參數
    我們在之前實例化 A 的時候,把實例化對象打印出來的話,是這樣的:
    A(a=1, b=10)
    那如果我們不想把特定的對象打印出來,可以這樣寫:
@dataclass
class A:
    a: int
    b: int = field(default=1, repr=False)

a = A(1)
print(a)

這時候,打印的結果為 A(a=1)

  • 設置是否計算 hash 的對象之一
    a: int = field(hash=False)
  • 設置是否成為和其他類進行對比的值之一
    a: int = field(compare=False)
  • 定義 field 信息
from dataclasses import field, dataclass, fields
@dataclass
class A:
    a: int = field(metadata={"name": "a"}) # metadata 需要接受一個映射對象,也就是 python 的字典

metadata = fields(A)
print(metadata)

打印的結果是
(Field(name='a',type=<class 'int'>,default=<dataclasses._MISSING_TYPE object at 0x10f2fe748>,default_factory=<dataclasses._MISSING_TYPE object at 0x10f2fe748>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'name': 'a'}),_field_type=_FIELD),)
是一個 tuple,第一個即是 a 字段的 field 定義
可以通過 metadata[0].metadata["name"] 獲取值

  • 自定義處理定義的參數
    有些字段需要我們進行一些預處理,不用傳遞初始值,由其他函數返回
    我們可以這么寫
def value():
    return "123"

@dataclass
class A:
    a: str = field(default_factory=value)

print(A().a) # 實例化 A 的時候已經可以不傳遞值了

打印的結果是 '123'

使用 dataclass 設定初始方法

使用裝飾器 dataclass 的時候,設定一些參數,即可選擇是否需要這些初始方法

  • __init__
@dataclass(init=False)
class A:
    a: int = 1

print(A())

打印結果
['__module__', '__annotations__', 'a', '__dict__', '__weakref__', '__doc__', '__dataclass_params__', '__dataclass_fields__', '__repr__', '__eq__', '__hash__', '__str__', '__getattribute__', '__setattr__', '__delattr__', '__lt__', '__le__', '__ne__', '__gt__', '__ge__', '__init__', '__new__', '__reduce_ex__', '__reduce__', '__subclasshook__', '__init_subclass__', '__format__', '__sizeof__', '__dir__', '__class__']
的確是沒有 __init__ 的

  • __repr__
    field 可以設置哪個參數不加入類返回值,設置
    @dataclass(repr=False) 即可
  • __hash__
    設置是否需要對類進行 hash,可以結合 a: int = field(hash=True) 一起設置
  • __eq__
    這是類之間比較使用的方法,
    同樣可以結合 a: int = field(compare=True) 一起設置

源碼剖析

dataclasses 這個庫這么強大,我們來一步步剖析它的源碼吧

field 源碼剖析

首先我們看看 field 的源碼

def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    if default is not MISSING and default_factory is not MISSING:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata)

這段代碼很簡單,對傳入的參數進行判斷之后,返回 Field 實例。
注意 default 和 default_factory 缺一不可,都是作為定義初始值的。
然后我們來看看 Field 的源碼:

class Field:
    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        self.metadata = (_EMPTY_METADATA
                         if metadata is None or len(metadata) == 0 else
                         types.MappingProxyType(metadata))
        self._field_type = None

    def __repr__(self):
        return ('Field('
                f'name={self.name!r},'
                f'type={self.type!r},'
                f'default={self.default!r},'
                f'default_factory={self.default_factory!r},'
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'hash={self.hash!r},'
                f'compare={self.compare!r},'
                f'metadata={self.metadata!r},'
                f'_field_type={self._field_type}'
                ')')

    def __set_name__(self, owner, name):
        func = getattr(type(self.default), '__set_name__', None)
        if func:
            # There is a __set_name__ method on the descriptor, call
            # it.
            func(self.default, owner, name)

基本沒有什么可以說的,就是簡單的類,功能也就一個 __set_name__
我們注意一下 __repr__ 里面的有個細節:
f'name={self.name!r},', 比如 self.name 為 "name", 這里會返回 "name='name',"

dataclass 源碼剖析

接下來我們來看 dataclass 的源碼

def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):

    def wrap(cls):
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)

    if _cls is None:
        return wrap
    return wrap(_cls)

這是一個很常見的裝飾器
當我們定義類的時候,把類本身作為 _cls 參數傳遞進去,這時候返回一個 _process_class 函數的值
實例化類的時候,這時候 _cls 為 None, 返回 wrap 對象

接着我們來看 _process_class 源碼
這段代碼比較長,我們刪減部分(不影響核心功能),刪除的是生成初始化函數的部分,有興趣的讀者可以自己查看一下。

def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
    fields = {}

    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen))
    any_frozen_base = False
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        base_fields = getattr(b, _FIELDS, None)
        if base_fields:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if getattr(b, _PARAMS).frozen:
                any_frozen_base = True
    cls_annotations = cls.__dict__.get('__annotations__', {})
    cls_fields = [_get_field(cls, name, type)
                  for name, type in cls_annotations.items()]
    for f in cls_fields:
        fields[f.name] = f
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)
    setattr(cls, _FIELDS, fields)

    if init:
        has_post_init = hasattr(cls, _POST_INIT_NAME)
        flds = [f for f in fields.values()
                if f._field_type in (_FIELD, _FIELD_INITVAR)]
        _set_new_attribute(cls, '__init__',
                           _init_fn(flds,
                                    frozen,
                                    has_post_init,
                                    '__dataclass_self__' if 'self' in fields
                                            else 'self',
                          ))

    return cls

 

這段代碼,最后將傳進來的 cls 返回出去,也就是返回的是類本身(初始化類的時候)

我們來看第一句代碼:

setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen))

_PARAMS 為前面定義的變量,值為 __dataclass_params__
_DataclassParams 是一個類
這句話就是把 _DataclassParams 實例作為值,__dataclass_params__ 作為屬性賦給 cls
所以,我們在查看定義的類的所有屬性的時候,會有一個 __dataclass_params__ 屬性,然后我們打印看看:
_DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
即是 _DataclassParams 實例

第二段代碼

fields = {}
any_frozen_base = False
has_dataclass_bases = False
for b in cls.__mro__[-1:0:-1]:
    base_fields = getattr(b, _FIELDS, None)
    if base_fields:
        has_dataclass_bases = True
        for f in base_fields.values():
            fields[f.name] = f
        if getattr(b, _PARAMS).frozen:
            any_frozen_base = True

 

前兩行都是定義變量,直接從第三行開始。
cls.__mro__[-1:0:-1] 這代表取 cls 本身和繼承的類,按照新式類的順序從子類到父類排序
(詳情見:mro
然后不要第一個(即自己本身),剩下的進行倒序排列,這時候,所有類的順序已經變成了父類到子類,這時候第一個為 object
_FIELDS 為前面定義的變量,為 __dataclass_fields__
輪詢排好序的類,如果由 __dataclass_fields__ 屬性,則進行前面的定義的變量操作,把所有的取到的值加入 fields
只有用 @dataclass 生成的類才會有這個屬性。

第三段代碼

cls_annotations = cls.__dict__.get('__annotations__', {})
cls_fields = [_get_field(cls, name, type)
                for name, type in cls_annotations.items()]
for f in cls_fields:
    fields[f.name] = f
    if isinstance(getattr(cls, f.name, None), Field):
        if f.default is MISSING:
            delattr(cls, f.name)
        else:
            setattr(cls, f.name, f.default)

cls_annotations = cls.__dict__.get('__annotations__', {})
這句話就是為了取出我們定義的所有字段
只要我們定義字段是

a: int
b: str

這樣的,就會自動有 __annotations__ 屬性
可以參看 PEP 526
然后賦予 cls 屬性操作
這步操作就是我們能夠進行類取值的關鍵

第四段代碼

setattr(cls, _FIELDS, fields) 

將 fields (最早定義的一個字典)作為值,賦給 cls 的屬性 __dataclass_fields__

第五段代碼

if init:
    has_post_init = hasattr(cls, _POST_INIT_NAME)
    flds = [f for f in fields.values()
            if f._field_type in (_FIELD, _FIELD_INITVAR)]
    _set_new_attribute(cls, '__init__',
                        _init_fn(flds,
                                frozen,
                                has_post_init,
                                '__dataclass_self__' if 'self' in fields
                                        else 'self',
                        ))

這段代碼表示,一旦設置 __init__=True,會在類里面加上這個方法。

def _set_new_attribute(cls, name, value):
    if name in cls.__dict__:
        return True
    setattr(cls, name, value)
    return False

_set_new_attribute 是一個為類賦予屬性的方法

至此,dataclass 源碼剖析完畢

make_dataclass 源碼剖析

源碼為:

def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):

    if namespace is None:
        namespace = {}
    else:
        namespace = namespace.copy()

    seen = set()
    anns = {}
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp, = item
        elif len(item) == 3:
            name, tp, spec = item
            namespace[name] = spec
        else:
            raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in seen:
            raise TypeError(f'Field name duplicated: {name!r}')

        seen.add(name)
        anns[name] = tp

    namespace['__annotations__'] = anns
    cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)

流程很詳細,就是解析我們定義的 fields,然后賦予 __annotations__屬性,最后使用 dataclass 生成一個類。
從其中的流程判斷來看,fields 里面最長只允許我們設置三個值,第一個名字,第二個類型,第三個是 fields 對象。
源碼剖析至此結束

尾聲

從功能上來看,dataclass 為我們帶來了比較好優化類方案,提供的各類方法也足夠用,可以在之后的項目里面逐漸使用起來。
從源碼上來看,源碼整體比較簡潔,使用了比較少見的 __annotations__,技巧足夠,代碼簡單易學。
建議新手可以從此入手,即可學習裝飾器也可學習優秀代碼。

 

轉自:https://zhuanlan.zhihu.com/p/60009941


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM