Python 正確重載運算符


  有些事情讓我不安,比如運算符重載。我決定不支持運算符重載,這完全是個人選擇,因為我見過太多 C++ 程序員濫用它。

                                                ——James Gosling

                                                    Java 之父

 

  運算符重載的作用是讓用戶定義的對象使用中綴運算符(如 + 和 |)或一元運算符(如 - 和 ~)。說得寬泛一些,在 Python 中,函數調用(())、屬性訪問(.)和元素訪問 / 切片([])也是運算符。

  我們為 Vector 類簡略實現了幾個運算符。__add__ 和 __mul__ 方法是為了展示如何使用特殊方法重載運算符,不過有些小問題被我們忽視了。此外,我們定義的Vector2d.__eq__ 方法認為 Vector(3, 4) == [3, 4] 是真的(True),這可能並不合理。

 

運算符重載基礎

  在某些圈子中,運算符重載的名聲並不好。這個語言特性可能(已經)被濫用,讓程序員困惑,導致缺陷和意料之外的性能瓶頸。但是,如果使用得當,API 會變得好用,代碼會變得易於閱讀。Python 施加了一些限制,做好了靈活性、可用性和安全性方面的平衡:

  • 不能重載內置類型的運算符
  • 不能新建運算符,只能重載現有的
  • 某些運算符不能重載——is、and、or 和 not(不過位運算符
  • &、| 和 ~ 可以)

  前面的博文已經為 Vector 定義了一個中綴運算符,即 ==,這個運算符由__eq__ 方法支持。我們將改進 __eq__ 方法的實現,更好地處理不是Vector 實例的操作數。然而,在運算符重載方面,眾多比較運算符(==、!=、>、<、>=、<=)是特例,因此我們首先將在 Vector 中重載四個算術運算符:一元運算符 - 和 +,以及中綴運算符 + 和 *。

 

一元運算符 

  -(__neg__)

    一元取負算術運算符。如果 x 是 -2,那么 -x == 2。

  +(__pos__)

    一元取正算術運算符。通常,x == +x,但也有一些例外。如果好奇,請閱讀“x 和 +x 何時不相等”附注欄。

  ~(__invert__)

    對整數按位取反,定義為 ~x == -(x+1)。如果 x 是 2,那么 ~x== -3。

  支持一元運算符很簡單,只需實現相應的特殊方法。這些特殊方法只有一個參數,self。然后,使用符合所在類的邏輯實現。不過,要遵守運算符的一個基本規則:始終返回一個新對象。也就是說,不能修改self,要創建並返回合適類型的新實例。

  對 - 和 + 來說,結果可能是與 self 同屬一類的實例。多數時候,+ 最好返回 self 的副本。abs(...) 的結果應該是一個標量。但是對 ~ 來說,很難說什么結果是合理的,因為可能不是處理整數的位,例如在ORM 中,SQL WHERE 子句應該返回反集。

1  def __abs__(self):
2         return math.sqrt(sum(x * x for x in self))
3 
4     def __neg__(self):
5         return Vector(-x for x  in self)            #為了計算 -v,構建一個新 Vector 實例,把 self 的每個分量都取反
6 
7     def __pos__(self):
8         return Vector(self)                        #為了計算 +v,構建一個新 Vector 實例,傳入 self 的各個分量 

x 和 +x 何時不相等 

  每個人都覺得 x == +x,而且在 Python 中,幾乎所有情況下都是這樣。但是,我在標准庫中找到兩例 x != +x 的情況。

  第一例與 decimal.Decimal 類有關。如果 x 是 Decimal 實例,在算術運算的上下文中創建,然后在不同的上下文中計算 +x,那么 x!= +x。例如,x 所在的上下文使用某個精度,而計算 +x 時,精度變了,例如下面的 🌰

算術運算上下文的精度變化可能導致 x 不等於 +x

>>> import decimal
>>> ctx = decimal.getcontext()                  #獲取當前全局算術運算符的上下文引用
>>> ctx.prec = 40                          #把算術運算上下文的精度設為40
>>> one_third = decimal.Decimal('1') / decimal.Decimal('3') #使用當前精度計算1/3
>>> one_third
Decimal('0.3333333333333333333333333333333333333333')     #查看結果,小數點后的40個數字
>>> one_third == +one_third                    #one_third = +one_thied返回TRUE
True
>>> ctx.prec = 28                          #把精度降為28
>>> one_third == +one_third                    #one_third = +one_thied返回False
False
>>> +one_third Decimal('0.3333333333333333333333333333')   #查看+one_third,小術后的28位數字

雖然每個 +one_third 表達式都會使用 one_third 的值創建一個新 Decimal 實例,但是會使用當前算術運算上下文的精度。

  x != +x 的第二例在 collections.Counter 的文檔中(https://docs.python.org/3/library/collections.html#collections.Counter)。類實現了幾個算術運算符,例如中綴運算符 +,作用是把兩個Counter 實例的計數器加在一起。然而,從實用角度出發,Counter 相加時,負值和零值計數會從結果中剔除。而一元運算符 + 等同於加上一個空 Counter,因此它產生一個新的Counter 且僅保留大於零的計數器。

🌰  一元運算符 + 得到一個新 Counter 實例,但是沒有零值和負值計數器

>>> from collections import Counter
>>> ct = Counter('abracadabra')
>>> ct['r'] = -3
>>> ct['d'] = 0
>>> ct
Counter({'a': 5, 'r': -3, 'b': 2, 'c': 1, 'd': 0})
>>> +ct
Counter({'a': 5, 'b': 2, 'c': 1})

 

重載向量加法運算符+

  兩個歐幾里得向量加在一起得到的是一個新向量,它的各個分量是兩個向量中相應的分量之和。比如說:

>>> v1 = Vector([3, 4, 5])
>>> v2 = Vector([6, 7, 8])
>>> v1 + v2
Vector([9.0, 11.0, 13.0])
>>> v1 + v2 == Vector([3+6, 4+7, 5+8])
True

確定這些基本的要求之后,__add__ 方法的實現短小精悍,🌰 如下

    def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)           #生成一個元祖,a來自self,b來自other,如果兩個長度不夠,通過fillvalue設置的補全值自動補全短的
        return Vector(a + b for a, b in pairs)                              #使用生成器表達式計算pairs中的各個元素的和

還可以把Vector 加到元組或任何生成數字的可迭代對象上:

1     # 在Vector類中定義    
2 
3     def __add__(self, other):
4         pairs = itertools.zip_longest(self, other, fillvalue=0.0)           #生成一個元祖,a來自self,b來自other,如果兩個長度不夠,通過fillvalue設置的補全值自動補全短的
5         return Vector(a + b for a, b in pairs)                              #使用生成器表達式計算pairs中的各個元素的和
6 
7     def __radd__(self, other):                                              #會直接委托給__add__
8         return self + other

  __radd__ 通常就這么簡單:直接調用適當的運算符,在這里就是委托__add__。任何可交換的運算符都能這么做。處理數字和向量時,+ 可以交換,但是拼接序列時不行。

 

重載標量乘法運算符* 

  Vector([1, 2, 3]) * x 是什么意思?如果 x 是數字,就是計算標量積(scalar product),結果是一個新 Vector 實例,各個分量都會乘以x——這也叫元素級乘法(elementwise multiplication)。

>>> v1 = Vector([1, 2, 3])
>>> v1 * 10
Vector([10.0, 20.0, 30.0])
>>> 11 * v1
Vector([11.0, 22.0, 33.0])

  涉及 Vector 操作數的積還有一種,叫兩個向量的點積(dotproduct);如果把一個向量看作 1×N 矩陣,把另一個向量看作 N×1 矩陣,那么就是矩陣乘法。NumPy 等庫目前的做法是,不重載這兩種意義的 *,只用 * 計算標量積。例如,在 NumPy 中,點積使用numpy.dot() 函數計算。

回到標量積的話題。我們依然先實現最簡可用的 __mul__ 和 __rmul__方法:

1     def __mul__(self, scalar):
2         if isinstance(scalar, numbers.Real):
3             return Vector(n * scalar for n in self)
4         else:
5             return NotImplemented
6 
7     def __rmul__(self, scalar):
8         return self * scalar

  這兩個方法確實可用,但是提供不兼容的操作數時會出問題。scalar參數的值要是數字,與浮點數相乘得到的積是另一個浮點數(因為Vector 類在內部使用浮點數數組)。因此,不能使用復數,但可以是int、bool(int 的子類),甚至 fractions.Fraction 實例等標量。

  提供了點積所需的 @ 記號(例如,a @ b 是 a 和 b 的點積)。@ 運算符由特殊方法 __matmul__、__rmatmul__ 和__imatmul__ 提供支持,名稱取自“matrix multiplication”(矩陣乘法)

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int' 

下面是相應特殊方法的代碼:

>>> va = Vector([1, 2, 3])
>>> vz = Vector([5, 6, 7])
>>> va @ vz == 38.0 # 1*5 + 2*6 + 3*7
True
>>> [10, 20, 30] @ vz
380.0
>>> va @ 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for @: 'Vector' and 'int'

 

眾多比較運算符

Python 解釋器對眾多比較運算符(==、!=、>、<、>=、<=)的處理與前文類似,不過在兩個方面有重大區別。

  • 正向和反向調用使用的是同一系列方法。例如,對 == 來說,正向和反向調用都是 __eq__ 方法,只是把參數對調了;而正向的 __gt__ 方法調用的是反向的 __lt__方法,並把參數對調。
  • 對 == 和 != 來說,如果反向調用失敗,Python 會比較對象的 ID,而不拋出 TypeError。

眾多比較運算符:正向方法返回NotImplemented的話,調用反向方法

 

分組

 

中綴運算符

 

正向方法調用

 

反向方法調用

 

后備機制

 

相等性

 

a == b

 

a.__eq__(b)

 

b.__eq__(a)

 

返回 id(a) == id(b)

 

 

a != b

 

a.__ne__(b)

 

b.__ne__(a)

 

返回 not (a == b)

 

排序

 

a > b

 

a.__gt__(b)

 

b.__lt__(a)

 

拋出 TypeError

 

 

a < b

 

a.__lt__(b)

 

b.__gt__(a)

 

拋出 TypeError

 

 

a >= b

 

a.__ge__(b)

 

b.__le__(a)

 

拋出 TypeError

 

 

a <= b

 

a.__le__(b)

 

b.__ge__(a)

 

拋出T ypeError

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

看下面的🌰

  1 from array import array
  2 import reprlib
  3 import math
  4 import numbers
  5 import functools
  6 import operator
  7 import itertools
  8 
  9 
 10 class Vector:
 11     typecode = 'd'
 12 
 13     def __init__(self, components):
 14         self._components = array(self.typecode, components)
 15 
 16     def __iter__(self):
 17         return iter(self._components)
 18 
 19     def __repr__(self):
 20         components = reprlib.repr(self._components)
 21         components = components[components.find('['):-1]
 22         return 'Vector({})'.format(components)
 23 
 24     def __str__(self):
 25         return str(tuple(self))
 26 
 27     def __bytes__(self):
 28         return (bytes([ord(self.typecode)]) + bytes(self._components))
 29 
 30     def __eq__(self, other):
 31         return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
 32 
 33     def __hash__(self):
 34         hashes = map(hash, self._components)
 35         return functools.reduce(operator.xor, hashes, 0)
 36 
 37     def __add__(self, other):
 38         pairs = itertools.zip_longest(self, other, fillvalue=0.0)           #生成一個元祖,a來自self,b來自other,如果兩個長度不夠,通過fillvalue設置的補全值自動補全短的
 39         return Vector(a + b for a, b in pairs)                              #使用生成器表達式計算pairs中的各個元素的和
 40 
 41     def __radd__(self, other):                                              #會直接委托給__add__
 42         return self + other
 43 
 44     def __mul__(self, scalar):
 45         if isinstance(scalar, numbers.Real):
 46             return Vector(n * scalar for n in self)
 47         else:
 48             return NotImplemented
 49 
 50     def __rmul__(self, scalar):
 51         return self * scalar
 52 
 53     def __matmul__(self, other):
 54         try:
 55             return sum(a * b for a, b in zip(self, other))
 56         except TypeError:
 57             return NotImplemented
 58 
 59     def __rmatmul__(self, other):
 60         return self @ other
 61 
 62     def __abs__(self):
 63         return math.sqrt(sum(x * x for x in self))
 64 
 65     def __neg__(self):
 66         return Vector(-x for x  in self)            #為了計算 -v,構建一個新 Vector 實例,把 self 的每個分量都取反
 67 
 68     def __pos__(self):
 69         return Vector(self)                         #為了計算 +v,構建一個新 Vector 實例,傳入 self 的各個分量
 70 
 71     def __bool__(self):
 72         return bool(abs(self))
 73 
 74     def __len__(self):
 75         return len(self._components)
 76 
 77     def __getitem__(self, index):
 78         cls = type(self)
 79 
 80         if isinstance(index, slice):
 81             return cls(self._components[index])
 82         elif isinstance(index, numbers.Integral):
 83             return self._components[index]
 84         else:
 85             msg = '{.__name__} indices must be integers'
 86             raise TypeError(msg.format(cls))
 87 
 88     shorcut_names = 'xyzt'
 89 
 90     def __getattr__(self, name):
 91         cls = type(self)
 92 
 93         if len(name) == 1:
 94             pos = cls.shorcut_names.find(name)
 95             if 0 <= pos < len(self._components):
 96                 return self._components[pos]
 97         msg = '{.__name__!r} object has no attribute {!r}'
 98         raise AttributeError(msg.format(cls, name))
 99 
100     def angle(self, n):
101         r = math.sqrt(sum(x * x for x in self[n:]))
102         a = math.atan2(r, self[n-1])
103         if (n == len(self) - 1 ) and (self[-1] < 0):
104             return math.pi * 2 - a
105         else:
106             return a
107 
108     def angles(self):
109         return (self.angle(n) for n in range(1, len(self)))
110 
111     def __format__(self, fmt_spec=''):
112         if fmt_spec.endswith('h'):
113             fmt_spec = fmt_spec[:-1]
114             coords = itertools.chain([abs(self)], self.angles())
115             outer_fmt = '<{}>'
116         else:
117             coords = self
118             outer_fmt = '({})'
119         components = (format(c, fmt_spec) for c in coords)
120         return outer_fmt.format(', '.join(components))
121 
122     @classmethod
123     def frombytes(cls, octets):
124         typecode = chr(octets[0])
125         memv = memoryview(octets[1:]).cast(typecode)
126         return cls(memv)
127 
128 va = Vector([1.0, 2.0, 3.0])
129 vb = Vector(range(1, 4))
130 print('va == vb:', va == vb)                 #兩個具有相同數值分量的 Vector 實例是相等的
131 t3 = (1, 2, 3)
132 print('va ==  t3:', va == t3)
133 
134 print('[1, 2] == (1, 2):', [1, 2] == (1, 2))

上面代碼執行返回的結果為:

va == vb: True
va ==  t3: True
[1, 2] == (1, 2): False

  從 Python 自身來找線索,我們發現 [1,2] == (1, 2) 的結果是False。因此,我們要保守一點,做些類型檢查。如果第二個操作數是Vector 實例(或者 Vector 子類的實例),那么就使用 __eq__ 方法的當前邏輯。否則,返回 NotImplemented,讓 Python 處理。

🌰 vector_v8.py:改進 Vector 類的 __eq__ 方法

1     def __eq__(self, other):
2         if isinstance(other, Vector):                                     #判斷對比的是否和Vector同屬一個實例
3             return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
4         else:
5             return NotImplemented                                         #否則,返回NotImplemented

改進以后的代碼執行結果:

>>> va = Vector([1.0, 2.0, 3.0])
>>> vb = Vector(range(1, 4))
>>> va == vb 
True
>>> t3 = (1, 2, 3)
>>> va == t3
False

 

增量賦值運算符 

  Vector 類已經支持增量賦值運算符 += 和 *= 了,示例如下

🌰  增量賦值不會修改不可變目標,而是新建實例,然后重新綁定

>>> v1 = Vector([1, 2, 3])
>>> v1_alias = v1             # 復制一份,供后面審查Vector([1, 2, 3])對象
>>> id(v1)                 # 記住一開始綁定給v1的Vector實例的ID
4302860128
>>> v1 += Vector([4, 5, 6])       # 增量加法運算
>>> v1                    # 結果與預期相符
Vector([5.0, 7.0, 9.0])
>>> id(v1)                 # 但是創建了新的Vector實例
4302859904
>>> v1_alias                # 審查v1_alias,確認原來的Vector實例沒被修改
Vector([1.0, 2.0, 3.0])
>>> v1 *= 11                # 增量乘法運算
>>> v1                   # 同樣,結果與預期相符,但是創建了新的Vector實例
Vector([55.0, 77.0, 99.0])
>>> id(v1)
4302858336

 

完整代碼:

  1 from array import array
  2 import reprlib
  3 import math
  4 import numbers
  5 import functools
  6 import operator
  7 import itertools
  8 
  9 
 10 class Vector:
 11     typecode = 'd'
 12 
 13     def __init__(self, components):
 14         self._components = array(self.typecode, components)
 15 
 16     def __iter__(self):
 17         return iter(self._components)
 18 
 19     def __repr__(self):
 20         components = reprlib.repr(self._components)
 21         components = components[components.find('['):-1]
 22         return 'Vector({})'.format(components)
 23 
 24     def __str__(self):
 25         return str(tuple(self))
 26 
 27     def __bytes__(self):
 28         return (bytes([ord(self.typecode)]) + bytes(self._components))
 29 
 30     def __eq__(self, other):
 31         if isinstance(other, Vector):                                     
 32             return (len(self) == len(other) and all(a == b for a, b in zip(self, other)))
 33         else:
 34             return NotImplemented                                     
 35 
 36     def __hash__(self):
 37         hashes = map(hash, self._components)
 38         return functools.reduce(operator.xor, hashes, 0)
 39 
 40     def __add__(self, other):
 41         pairs = itertools.zip_longest(self, other, fillvalue=0.0)           
 42         return Vector(a + b for a, b in pairs)                              
 43 
 44     def __radd__(self, other):                                             
 45         return self + other
 46 
 47     def __mul__(self, scalar):
 48         if isinstance(scalar, numbers.Real):
 49             return Vector(n * scalar for n in self)
 50         else:
 51             return NotImplemented
 52 
 53     def __rmul__(self, scalar):
 54         return self * scalar
 55 
 56     def __matmul__(self, other):
 57         try:
 58             return sum(a * b for a, b in zip(self, other))
 59         except TypeError:
 60             return NotImplemented
 61 
 62     def __rmatmul__(self, other):
 63         return self @ other
 64 
 65     def __abs__(self):
 66         return math.sqrt(sum(x * x for x in self))
 67 
 68     def __neg__(self):
 69         return Vector(-x for x  in self)           
 70 
 71     def __pos__(self):
 72         return Vector(self)                         
 73 
 74     def __bool__(self):
 75         return bool(abs(self))
 76 
 77     def __len__(self):
 78         return len(self._components)
 79 
 80     def __getitem__(self, index):
 81         cls = type(self)
 82 
 83         if isinstance(index, slice):
 84             return cls(self._components[index])
 85         elif isinstance(index, numbers.Integral):
 86             return self._components[index]
 87         else:
 88             msg = '{.__name__} indices must be integers'
 89             raise TypeError(msg.format(cls))
 90 
 91     shorcut_names = 'xyzt'
 92 
 93     def __getattr__(self, name):
 94         cls = type(self)
 95 
 96         if len(name) == 1:
 97             pos = cls.shorcut_names.find(name)
 98             if 0 <= pos < len(self._components):
 99                 return self._components[pos]
100         msg = '{.__name__!r} object has no attribute {!r}'
101         raise AttributeError(msg.format(cls, name))
102 
103     def angle(self, n):
104         r = math.sqrt(sum(x * x for x in self[n:]))
105         a = math.atan2(r, self[n-1])
106         if (n == len(self) - 1 ) and (self[-1] < 0):
107             return math.pi * 2 - a
108         else:
109             return a
110 
111     def angles(self):
112         return (self.angle(n) for n in range(1, len(self)))
113 
114     def __format__(self, fmt_spec=''):
115         if fmt_spec.endswith('h'):
116             fmt_spec = fmt_spec[:-1]
117             coords = itertools.chain([abs(self)], self.angles())
118             outer_fmt = '<{}>'
119         else:
120             coords = self
121             outer_fmt = '({})'
122         components = (format(c, fmt_spec) for c in coords)
123         return outer_fmt.format(', '.join(components))
124 
125     @classmethod
126     def frombytes(cls, octets):
127         typecode = chr(octets[0])
128         memv = memoryview(octets[1:]).cast(typecode)
129         return cls(memv)
View Code

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM