详解Python的迭代器、生成器以及相关的itertools包

时间:2022-09-14 13:45:12

对数学家来说,Python这门语言有着很多吸引他们的地方。举几个例子:对于tuple、lists以及sets等容器的支持,使用与传统数学类似的符号标记方式,还有列表推导式这样与数学中集合推导式和集的结构式(set-builder notation)很相似的语法结构。

另外一些很吸引数学爱好者的特性是Python中的iterator(迭代器)、generator(生成器)以及相关的itertools包。这些工具帮助人们能够很轻松的写出处理诸如无穷序列(infinite sequence)、随机过程(stochastic processes)、递推关系(recurrence relations)以及组合结构(combinatorial structures)等数学对象的优雅代码。本文将涵盖我关于迭代器和生成器的一些笔记,并且有一些我在学习过程中积累的相关经验。
Iterators

迭代器(Iterator)是一个可以对集合进行迭代访问的对象。通过这种方式不需要将集合全部载入内存中,也正因如此,这种集合元素几乎可以是无限的。你可以在Python官方文档的“迭代器类型(Iterator Type)”部分找到相关文档。

让我们对定义的描述再准确些,如果一个对象定义了__iter__方法,并且此方法需要返回一个迭代器,那么这个对象就是可迭代的(iterable)。而迭代器是指实现了__iter__以及next(在Python 3中为__next__)两个方法的对象,前者返回一个迭代器对象,而后者返回迭代过程的下一个集合元素。据我所知,迭代器总是在__iter__方法中简单的返回自己(self),因为它们正是自己的迭代器。

一般来说,你应该避免直接调用__iter__以及next方法。而应该使用for或是列表推导式(list comprehension),这样的话Python能够自动为你调用这两个方法。如果你需要手动调用它们,请使用Python的内建函数iter以及next,并且把目标迭代器对象或是集合对象当做参数传递给它们。举个例子,如果c是一个可迭代对象,那么你可以使用iter(c)来访问,而不是c.__iter__(),类似的,如果a是一个迭代器对象,那么请使用next(a)而不是a.next()来访问下一个元素。与之相类似的还有len的用法。

说到len,值得注意的是对迭代器而言没必要去纠结length的定义。所以它们通常不会去实现__len__方法。如果你需要计算容器的长度,那么必须得手动计算,或者使用sum。本文末,在itertools模块之后会给出一个例子。

有一些可迭代对象并不是迭代器,而是使用其他对象作为迭代器。举个例子,list对象是一个可迭代对象,但并不是一个迭代器(它实现了__iter__但并未实现next)。通过下面的例子你可以看到list是如何使用迭代器listiterator的。同时值得注意的是list很好地定义了length属性,而listiterator却没有。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
>>> a = [1, 2]
>>> type(a)
<type 'list'>
>>> type(iter(a))
<type 'listiterator'>
>>> it = iter(a)
>>> next(it)
1
>>> next(it)
2
>>> next(it)
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
StopIteration
>>> len(a)
2
>>> len(it)
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
TypeError: object of type 'listiterator' has no len()

当迭代结束却仍然被继续迭代访问时,Python解释器会抛出StopIteration异常。然而,前述中提到迭代器可以迭代一个无穷集合,所以对于这种迭代器就必须由用户负责确保不会造成无限循环的情况,请看下面的例子:
 

?
1
2
3
4
5
6
7
8
9
10
class count_iterator(object):
  n = 0
 
  def __iter__(self):
    return self
 
  def next(self):
    y = self.n
    self.n += 1
    return y

下面是例子,注意最后一行试图将一个迭代器对象转为list,这将导致一个无限循环,因为这种迭代器对象将不会停止。
 

?
1
2
3
4
5
6
7
8
9
10
>>> counter = count_iterator()
>>> next(counter)
0
>>> next(counter)
1
>>> next(counter)
2
>>> next(counter)
3
>>> list(counter) # This will result in an infinite loop!

最后,我们将修改以上的程序:如果一个对象没有__iter__方法但定义了__getitem__方法,那么这个对象仍然是可迭代的。在这种情况下,当Python的内建函数iter将会返回一个对应此对象的迭代器类型,并使用__getitem__方法遍历list的所有元素。如果StopIteration或IndexError异常被抛出,则迭代停止。让我们看看以下的例子:
 

?
1
2
3
4
5
6
class SimpleList(object):
  def __init__(self, *items):
    self.items = items
 
  def __getitem__(self, i):
    return self.items[i]

用法在此:
 

?
1
2
3
4
5
6
7
8
9
10
11
12
>>> a = SimpleList(1, 2, 3)
>>> it = iter(a)
>>> next(it)
1
>>> next(it)
2
>>> next(it)
3
>>> next(it)
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
StopIteration

现在来看一个更有趣的例子:根据初始条件使用迭代器生成Hofstadter Q序列。Hofstadter在他的著作《G?del, Escher, Bach: An Eternal Golden Braid》中首次提到了这个嵌套的序列,并且自那时候开始关于证明这个序列对所有n都成立的问题就开始了。以下的代码使用一个迭代器来生成给定n的Hofstadter序列,定义如下:

?
1
Q(n)=Q(n-Q(n-1))+Q(n?Q(n?2))

给定一个初始条件,举个例子,qsequence([1, 1])将会生成H序列。我们使用StopIteration异常来指示序列不能够继续生成了,因为需要一个合法的下标索引来生成下一个元素。例如如果初始条件是[1,2],那么序列生成将立即停止。
 

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class qsequence(object):
  def __init__(self, s):
    self.s = s[:]
 
  def next(self):
    try:
      q = self.s[-self.s[-1]] + self.s[-self.s[-2]]
      self.s.append(q)
      return q
    except IndexError:
      raise StopIteration()
 
  def __iter__(self):
    return self
 
  def current_state(self):
    return self.s

用法在此:
 

?
1
2
3
4
5
6
7
>>> Q = qsequence([1, 1])
>>> next(Q)
2
>>> next(Q)
3
>>> [next(Q) for __ in xrange(10)]
[3, 4, 5, 5, 6, 6, 6, 8, 8, 8]
Generators

生成器(Generator)是一种用更简单的函数表达式定义的生成器。说的更具体一些,在生成器内部会用到yield表达式。生成器不会使用return返回值,而当需要时使用yield表达式返回结果。Python的内在机制能够帮助记住当前生成器的上下文,也就是当前的控制流和局部变量的值等。每次生成器被调用都适用yield返回迭代过程中的下一个值。__iter__方法是默认实现的,意味着任何能够使用迭代器的地方都能够使用生成器。下面这个例子实现的功能同上面迭代器的例子一样,不过代码更紧凑,可读性更强。
 

?
1
2
3
4
5
def count_generator():
  n = 0
  while True:
   yield n
   n += 1

来看看用法:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
>>> counter = count_generator()
>>> counter
<generator object count_generator at 0x106bf1aa0>
>>> next(counter)
0
>>> next(counter)
1
>>> iter(counter)
<generator object count_generator at 0x106bf1aa0>
>>> iter(counter) is counter
True
>>> type(counter)
<type 'generator'>

现在让我们尝试用生成器来实现Hofstadter's Q队列。这个实现很简单,不过我们却不能实现前的类似于current_state那样的函数了。因为据我所知,不可能在外部直接访问生成器内部的变量状态,因此如current_state这样的函数就不可能实现了(虽然有诸如gi_frame.f_locals这样的数据结构可以做到,但是这毕竟是CPython的特殊实现,并不是这门语言的标准部分,所以并不推荐使用)。如果需要访问内部变量,一个可能的方法是通过yield返回所有的结果,我会把这个问题留作练习。
 

?
1
2
3
4
5
6
7
8
9
def hofstadter_generator(s):
  a = s[:]
  while True:
    try:
      q = a[-a[-1]] + a[-a[-2]]
      a.append(q)
      yield q
    except IndexError:
      return

请注意,在生成器迭代过程的结尾有一个简单的return语句,但并没有返回任何数据。从内部来说,这将抛出一个StopIteration异常。

下一个例子来自Groupon的面试题。在这里我们首先使用两个生成器来实现Bernoulli过程,这个过程是一个随机布尔值的无限序列,True的概率是p而False的概率为q=1-p。随后实现一个von Neumann extractor,它从Bernoulli process获取输入(0<p<1),并且返回另一个Bernoulli process(p=0.5)。
 

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import random
 
def bernoulli_process(p):
  if p > 1.0 or p < 0.0:
    raise ValueError("p should be between 0.0 and 1.0.")
 
  while True:
    yield random.random() < p
 
def von_neumann_extractor(process):
  while True:
    x, y = process.next(), process.next()
    if x != y:
      yield x

最后,生成器是一种生成随机动态系统的很有利的工具。下面这个例子将演示著名的帐篷映射(tent map)动态系统是如何通过生成器实现的。(插句题外话,看看数值的不准确性是如何开始关联变化并呈指数式增长的,这是一个如帐篷映射这样的动态系统的关键特征)。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
>>> def tent_map(mu, x0):
...  x = x0
...  while True:
...    yield x
...    x = mu * min(x, 1.0 - x)
...
>>>
>>> t = tent_map(2.0, 0.1)
>>> for __ in xrange(30):
...  print t.next()
...
0.1
0.2
0.4
0.8
0.4
0.8
0.4
0.8
0.4
0.8
0.4
0.8
0.4
0.8
0.4
0.8
0.4
0.799999999999
0.400000000001
0.800000000003
0.399999999994
0.799999999988
0.400000000023
0.800000000047
0.399999999907
0.799999999814
0.400000000373
0.800000000745
0.39999999851
0.79999999702

另一个相似的例子是Collatz序列。
 

?
1
2
3
4
5
def collatz(n):
  yield n
  while n != 1:
   n = n / 2 if n % 2 == 0 else 3 * n + 1
   yield n

请注意在这个例子中,我们仍旧没有手动抛出StopIteration异常,因为它会在控制流到达函数结尾的时候自动抛出。

请看用法:
 

?
1
2
3
4
5
6
7
8
9
10
>>> # If the Collatz conjecture is true then list(collatz(n)) for any n will
... # always terminate (though your machine might run out of memory first!)
>>> list(collatz(7))
[7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
>>> list(collatz(13))
[13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
>>> list(collatz(17))
[17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
>>> list(collatz(19))
[19, 58, 29, 88, 44, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]
Recursive Generators

生成器可以像其它函数那样递归。让我们来看一个自实现的简单版本的itertools.permutations,这个生成器通过给定一个item列表生成其全排列(在实际中请使用itertools.permutations,那个实现更快)。基本思想很简单:对列表中的每一个元素,我们通过将它与列表第一个元素交换将其放置到第一的位置上去,而后重新递归排列列表的剩余部分。
 

?
1
2
3
4
5
6
7
8
9
def permutations(items):
  if len(items) == 0:
    yield []
  else:
    pi = items[:]
    for i in xrange(len(pi)):
      pi[0], pi[i] = pi[i], pi[0]
      for p in permutations(pi[1:]):
        yield [pi[0]] + p

 

?
1
2
3
4
5
6
7
8
9
>>> for p in permutations([1, 2, 3]):
...   print p
...
[1, 2, 3]
[1, 3, 2]
[2, 1, 3]
[2, 3, 1]
[3, 1, 2]
[3, 2, 1]
Generator Expressions

生成器表达式可以让你通过一个简单的,单行声明定义生成器。这跟Python中的列表推导式非常类似,举个例子,下面的代码将定义一个生成器迭代所有的完全平方。注意生成器表达式的返回结果是一个生成器类型对象,它实现了next和__iter__两个方法。
 

?
1
2
3
4
5
6
7
8
9
10
11
12
13
>>> g = (x ** 2 for x in itertools.count(1))
>>> g
<generator object <genexpr> at 0x1029a5fa0>
>>> next(g)
1
>>> next(g)
4
>>> iter(g)
<generator object <genexpr> at 0x1029a5fa0>
>>> iter(g) is g
True
>>> [g.next() for __ in xrange(10)]
[9, 16, 25, 36, 49, 64, 81, 100, 121, 144]

同样可以使用生成器表达式实现Bernoulli过程,在这个例子中p=0.4。如果一个生成器表达式需要另一个迭代器作为循环指示器,并且这个生辰器表达式使用在无限序列上的,那么itertools.count将是一个很好的选择。若非如此,xrange将是一个不错的选择。
 

?
1
2
3
>>> g = (random.random() < 0.4 for __ in itertools.count())
>>> [g.next() for __ in xrange(10)]
[False, False, False, True, True, False, True, False, False, True]

正如前面提到的,生成器表达式能够用在任何需要迭代器作为参数的地方。举个例子,我们可以通过如下代码计算前十个全平方数的累加和:
 

?
1
2
>>> sum(x ** 2 for x in xrange(10))
285

更多生成器表达式的例子将在下一节给出。
itertools模块

itertools模块提供了一系列迭代器能够帮助用户轻松地使用排列、组合、笛卡尔积或其他组合结构。

在开始下面的部分之前,注意到上面给出的所有代码都是未经优化的,在这里只是充当一个示例的作用。在实际使用中,你应该避免自己去实现排列组合除非你能够有更好的想法,因为枚举的数量可是按照指数级增加的。

让我们先从一些有趣的用例开始。第一个例子来看如何写一个常用的模式:循环遍历一个三维数组的所有下标元素,并且循环遍历满足0≤i<j<k≤n条件的所有下标。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from itertools import combinations, product
 
n = 4
d = 3
 
def visit(*indices):
  print indices
 
# Loop through all possible indices of a 3-D array
for i in xrange(n):
  for j in xrange(n):
    for k in xrange(n):
      visit(i, j, k)
 
# Equivalent using itertools.product
for indices in product(*([xrange(n)] * d)):
  visit(*indices)
 
# Now loop through all indices 0 <= i < j < k <= n
for i in xrange(n):
  for j in xrange(i + 1, n):
    for k in xrange(j + 1, n):
      visit(i, j, k)
 
# And equivalent using itertools.combinations
for indices in combinations(xrange(n), d):
  visit(*indices)

使用itertools模块提供的枚举器有两个好处:代码能够在单行内完成,并且很容易扩展到更高维度。我并未比较for方法和itertools两种方法的性能,也许跟n有很大关系。如果你想的话请自行测试评判。

第二个例子,来做一些有趣的数学题:使用生成器表达式、itertools.combinations以及itertools.permutations来计算排列的逆序数,并且计算一个列表全排列逆序数之和。如OEIS A001809所示,求和的结果趋近于n!n(n-1)/4。在实际使用中直接通过这公式计算要比上面的代码更高效,不过我写这个例子是为了练习itertools枚举器的使用。
 

?
1
2
3
4
5
6
7
8
9
10
import itertools
import math
 
def inversion_number(A):
  """Return the number of inversions in list A."""
  return sum(1 for x, y in itertools.combinations(xrange(len(A)), 2) if A[x] > A[y])
 
def total_inversions(n):
  """Return total number of inversions in permutations of n."""
  return sum(inversion_number(A) for A in itertools.permutations(xrange(n)))

用法如下:
 

?
1
2
3
4
5
>>> [total_inversions(n) for n in xrange(10)]
[0, 0, 1, 9, 72, 600, 5400, 52920, 564480, 6531840]
 
>>> [math.factorial(n) * n * (n - 1) / 4 for n in xrange(10)]
[0, 0, 1, 9, 72, 600, 5400, 52920, 564480, 6531840]

第三个例子,通过brute-force counting方法计算recontres number。recontres number的定义在此。首先,我们写了一个函数在一个求和过程中使用生成器表达式去计算排列中fixed points出现的个数。然后在求和中使用itertools.permutations和其他生成器表达式计算包含n个数并且有k个fixed points的排列的总数。然后得到结果。当然了,这个实现方法是效率低下的,不提倡在实际应用中使用。再次重申,这只是为了掩饰生成器表达式以及itertools相关函数使用方法的示例。
 

?
1
2
3
4
5
6
7
def count_fixed_points(p):
  """Return the number of fixed points of p as a permutation."""
  return sum(1 for x in p if p[x] == x)
 
def count_partial_derangements(n, k):
  """Returns the number of permutations of n with k fixed points."""
  return sum(1 for p in itertools.permutations(xrange(n)) if count_fixed_points(p) == k)

用法:
 

?
1
2
3
# Usage:
>>> [count_partial_derangements(6, i) for i in xrange(7)]
[265, 264, 135, 40, 15, 0, 1]