前言

  最初认识这个库是因为 dayu_widgets 里面用到了这个进行函数的重载。
  对于它能实现的效果还是挺感兴趣的。

  singledispatch 可以实现函数的泛型重载
  可以使用 pip install singledispatch 安装使用, github 地址
  下面是官方提供的案例整合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from singledispatch import singledispatch
from decimal import Decimal

@singledispatch
def fun(arg, verbose=False):
if verbose:
print("Let me just say,", end=" ")
print(arg)

@fun.register(int)
def _(arg, verbose=False):
if verbose:
print("Strength in numbers, eh?", end=" ")
print(arg)

@fun.register(list)
def _(arg, verbose=False):
if verbose:
print("Enumerate this:")
for i, elem in enumerate(arg):
print(i, elem)

def nothing(arg, verbose=False):
print("Nothing.")

fun.register(type(None), nothing)

@fun.register(float)
@fun.register(Decimal)
def fun_num(arg, verbose=False):
if verbose:
print("Half of your number:", end=" ")
print(arg / 2)

fun("Hello, world.")
# Hello, world.
fun("test.", verbose=True)
# Let me just say, test.
fun(42, verbose=True)
# Strength in numbers, eh? 42
fun(['spam', 'spam', 'eggs', 'spam'], verbose=True)
# Enumerate this:
# 0 spam
# 1 spam
# 2 eggs
# 3 spam
fun(None)
# Nothing.
fun(1.23)
# 0.615

  可以看到 singledispatch 根据第一个参数的类型调用不同函数的功能。
  这个功能在 Python 3.4 之后引入到 functools 里面。

原理分析

singledispatch 装饰器拆解

  首先看装饰器 @singledispatch 的作用
  装饰器整体代码可以简化为如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from functools import update_wrapper

def singledispatch(func):

registry = {}

def dispatch(cls):
...


def register(cls, func=None):
...

def wrapper(*args, **kw):
return dispatch(args[0].__class__)(*args, **kw)

registry[object] = func
wrapper.register = register
wrapper.dispatch = dispatch
# update_wrapper 等价于 @wraps
update_wrapper(wrapper, func)
return wrapper

  @singledispatch 包装函数之后返回 wrapper 对象
  wrapper 同时添加 register & dispatch 方法
  背后调用其实是将第一个参数的类型放到 dispatch 函数进行 调用分发。

  因为 wrapper.register 有这个设置。
  如果我们把 装饰器 的语法糖拆除就很清楚到底发生了什么

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# @singledispatch
# def fun(*args, **kwargs):
# pass
# 等价于下面的写法 ↓↓↓

from singledispatch import singledispatch

def fun(*args, **kwargs):
pass
fun = singledispatch(fun)

# --------------------------------

print(fun.register)
# <function singledispatch.<locals>.register at 0x0000027A6D0DE670>

  经过 装饰器 封装之后 fun 就会有 wrapper.register 方法了。
  那只要 register 方法也是一个装饰器的写法,就可以继续沿用 @ 装饰器的写法。
  从这也可以理解为啥后续 register 的函数命名已经不重要了。

regsiter 函数分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from functools import update_wrapper

def get_cache_token():
return ABCMeta._abc_invalidation_counter

def singledispatch(func):

registry = {}
# dispatch_cache = WeakKeyDictionary()
def ns(): pass
ns.cache_token = None

def dispatch(cls):
...

def register(cls, func=None):
if func is None:
return lambda f: register(cls, f)
registry[cls] = func
if ns.cache_token is None and hasattr(cls, '__abstractmethods__'):
ns.cache_token = get_cache_token()
# dispatch_cache.clear()
return func

def wrapper(*args, **kw):
...

# 省略 ...

  register 使用了非常取巧的方式构建带参数的装饰器。
  如果 func 没有传参就返回一个 lambda 来接收参数
  然后会将当前获取的类型存放到 registry 里面
  get_cache_token 是获取 ABCMeta._abc_invalidation_counter 的计数
  因为使用了 ABCMeta 元类会影响到 mro 判断,这里可以先抛开不提。

  register 函数主要是给 registry 字典添加 类型 对应 func 的匹配关系

dispatch 函数分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from functools import update_wrapper
from weakref import WeakKeyDictionary

def singledispatch(func):

registry = {}
# NOTE 使用 WeakKeyDictionary
dispatch_cache = WeakKeyDictionary()
# def ns(): pass
# ns.cache_token = None

def dispatch(cls):

NOTE ABCMeta 数量发生变化重置 缓存
if ns.cache_token is not None:
current_token = get_cache_token()
if ns.cache_token != current_token:
dispatch_cache.clear()
ns.cache_token = current_token

try:
# NOTE 从 cache 取值
impl = dispatch_cache[cls]
except KeyError:
try:
# NOTE 从 registry 取值
impl = registry[cls]
except KeyError:
# NOTE 没有匹配的类型,可能是用户扩展的类型,查找 mro 找到最匹配的 方法。
impl = _find_impl(cls, registry)
# NOTE 存放到缓存里面
dispatch_cache[cls] = impl
return impl

def register(cls, func=None):
# 省略 ...
# NOTE 重置缓存
dispatch_cache.clear()
# 省略 ...

def wrapper(*args, **kw):
...

# 省略 ...

  WeakKeyDictionary 的用法可以参考这篇文章 链接
  大概就是如果 key 的对象已经不存在的话,那么 WeakKeyDictionary 会自动清理这个键值对

  为什么这里需要用到 WeakKeyDictionary ,因为 dispatch 传递的 cls 可能是用户扩展的类型。
  用户也有可能处于某些原因直接删除了这个 cls 类型导致缓存出问题,所以用 WeakKeyDictionary 这种骚操作可以轻松解决问题。

  应用场景如下 ↓↓↓

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from singledispatch import singledispatch

class TestDict(dict):
def __init__(self,data):
super(TestDict, self).__init__(data)

def some_method(self):
print("some_method")


@singledispatch
def fun(*args, **kwargs):
print('original',args)

@fun.register(dict)
def _(data):
for key, val in data.items():
print(key,val)

a = {'a':1}

fun(a)
# a 1
fun(TestDict(a))
# a 1

fun(['a'])
# original (['a'],)
fun({'as'})
# original ({'as'},)

  在这种自定义扩展的情况下,需要根据 TestDict 类型的 mro 匹配出最符合条件的 registry 类型,调用相关的方法。
  _find_impl 就是来干这个事情的。

_find_impl 剖析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58

def _compose_mro(cls, types):
"""Calculates the method resolution order for a given class *cls*.

Includes relevant abstract base classes (with their respective bases) from
the *types* iterable. Uses a modified C3 linearization algorithm.

"""
bases = set(cls.__mro__)
# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (typ not in bases and hasattr(typ, '__mro__')
and issubclass(cls, typ))
types = [n for n in types if is_related(n)]
# Remove entries which are strict bases of other entries (they will end up
# in the MRO anyway.
def is_strict_base(typ):
for other in types:
if typ != other and typ in other.__mro__:
return True
return False
types = [n for n in types if not is_strict_base(n)]
# Subclasses of the ABCs in *types* which are also implemented by
# *cls* can be used to stabilize ABC ordering.
type_set = set(types)
mro = []
for typ in types:
found = []
for sub in typ.__subclasses__():
if sub not in bases and issubclass(cls, sub):
found.append([s for s in sub.__mro__ if s in type_set])
if not found:
mro.append(typ)
continue
# Favor subclasses with the biggest number of useful bases
found.sort(key=len, reverse=True)
for sub in found:
for subcls in sub:
if subcls not in mro:
mro.append(subcls)
return _c3_mro(cls, abcs=mro)

def _find_impl(cls, registry):
mro = _compose_mro(cls, registry.keys())
match = None
for t in mro:
if match is not None:
# If *match* is an implicit ABC but there is another unrelated,
# equally matching implicit ABC, refuse the temptation to guess.
if (t in registry and t not in cls.__mro__
and match not in cls.__mro__
and not issubclass(match, t)):
raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
match, t))
break
if t in registry:
match = t
return registry.get(match)

  _c3_mro 会根据就是根据 _compose_mro 过滤的信息重新计算一遍 mro 顺序
  mro 计算采用了 c3 算法,具体的计算过程可以参考 链接 Python官网

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
下面的伪代码比较好的阐明了 C3 算法的作用
L[object] = [object]
L[D] = [D, object]
L[E] = [E, object]
L[F] = [F, object]
L[B] = [B, D, E, object]
L[C] = [C, D, F, object]

L[A] = [A] + merge(L[B], L[C], [B], [C])
= [A] + merge([B, D, E, object], [C, D, F, object], [B], [C])
= [A, B] + merge([D, E, object], [C, D, F, object], [C])
= [A, B, C] + merge([D, E, object], [D, F, object])
= [A, B, C, D] + merge([E, object], [F, object])
= [A, B, C, D, E] + merge([object], [F, object])
= [A, B, C, D, E, F] + merge([object], [object])
= [A, B, C, D, E, F, object]

merge 会去取各个数组里面排第一位后面没有重复且排到第一位以外的对象

  不过这里的 singledispatch mro 还考虑 ABCMeta 元类的影响。
  所以需要构建一个特殊的 c3 算法进行处理。
  也直接导致 singledispatch 复杂了很多。

  因为大多数情况下,很少会用到 Python 的 ABCMeta 进行编程。
  所以 _compose_mro 大都是返回了 python __mro__ 的顺序
  然后再从 mro 继承顺序里找出最匹配 registry 存储对象的函数进行调用。

总结

  singledispatch 用 Python 实现 c3 算法还挺有意思的,我这里就没有详细列出来了。
  建议去看源码学习。