2017年3月11日土曜日

python3 クロージャを使ったメモ化

Notebook

ここを参考にクロージャを使ったメモ化を練習する.
クロージャ機能とは,グローバルでない所で定義された関数が,その定義されたスコープにある変数の情報を覚えている機能のこと.元ネタではフィボナッチ列を計算しているので,今回はコラッツ予想を1からnまで示す関数を実装してみた.

In [1]:
def collatz_func():
    s = set([1])
    
    def collatz(n):
        sub_s = set()
        while True:
            if n in s:
                for e in sub_s: # 1
                    s.add(e)
                break
            else:
                sub_s.add(n)
                if n % 2 == 0:
                    n = int(n / 2)
                else:
                    n = 3 * n + 1
        return True
    
    return collatz

#1 では s = s.union(sub_s)としたかったが,これの理由で怒られる.
collatz_func()を実行すると,実際に計算を行う関数を返してくる.

In [2]:
collatz = collatz_func()
def col_range(n):
    for i in range(1, n+1):
        collatz(i)
In [3]:
def normal_col(n):
    while True:
        if n == 1:
            return True
        else:
            if n % 2 == 0:
                n = n / 2
            else:
                n = 3 * n + 1
                
def normal_col_range(n):
    for i in range(1, n+1):
        normal_col(i)

こっちは素朴に計算する関数.測度を比較しよう.

In [4]:
%time col_range(100000)
CPU times: user 148 ms, sys: 8 ms, total: 156 ms
Wall time: 156 ms

In [5]:
%time normal_col_range(100000)
CPU times: user 3.14 s, sys: 0 ns, total: 3.14 s
Wall time: 3.14 s

メモ化の効果は歴然だが,col_rangeにバグがあるだけかもしれないので,素朴に計算しつつ現れた数を適当なセットに加えていく関数を作って,collatzが作った検証済み自然数の集合と比較する.

In [6]:
s2 = set()
def normal_col2(n):
    while True:
        s2.add(n)
        if n == 1:
            return True
        else:
            if n % 2 == 0:
                n = int(n / 2)
            else:
                n = 3 * n + 1
                
def normal_col_range2(n):
    for i in range(1, n+1):
        normal_col2(i)
        
normal_col_range2(100000)
In [7]:
collatz.__closure__[0].cell_contents == s2 
Out[7]:
True

collatzはちゃんと計算を行っている. こちらによると,簡単にメモ化を実現するデコレータがあるようだ(デコレーターはまた今度勉強する).とりあえず試してみる.

In [8]:
import functools
@functools.lru_cache(maxsize=None)
def deco_collatz(n):
    for i in range(1, 1+n):
        while True:
            if i == 1:
                break
            if i % 2 == 0:
                i = int(i / 2)
            else:
                i = 3 * i + 1
In [9]:
%time deco_collatz(100000) 
CPU times: user 2.02 s, sys: 0 ns, total: 2.02 s
Wall time: 2.02 s

ちょっとの手間で高速化が実現できている.