skip to content
Site header image satoooh.org

Python で二項係数 nCr を高速に計算したい

AtCoder の問題を解いていると、 mod p の条件下で高速に二項係数 nCr を求める場面に多く遭遇するのでそのあたりの知識をまとめます。

Last Updated:

二項係数 nCr を計算する

まずは普通に計算してみましょう。これは二項係数を階乗を用いて表示した

nCr=n!r!×(nr)!_n \mathrm{C} _r = \frac{n!}{r! \times (n-r)!}

より、次のように求めることができます。

from math import factorial
print(factorial(n) // factorial(r) // factorial(n - r))

mod p の条件下で nCr を計算する

競プロの世界では nCr_n \mathrm{C} _r を計算する問題に、 109+710^9 + 7 で割ったあまりを求めさせるという制約が付いたものが多く存在します。

今回はこのような条件下で高速に nCr_n \mathrm{C} _r を計算するにはどうすればよいかを考えます。(きょうの本題)

この条件がつく時点で大きな n, r を扱うことが多く、高速に計算できるアルゴリズムを考えないと TLE 祭りになってしまうわけですね。

結論

結論から言うと、 nCr=(n!)×(r!)1×((nr)!)1_n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1} となることを利用して、次のように計算することで高速に処理を行うことが可能です。

def cmb(n, r, p):
    if (r < 0) or (n < r):
        return 0
    r = min(r, n - r)
    return fact[n] * factinv[r] * factinv[n-r] % p

p = 10 ** 9 + 7
N = 10 ** 6  # N は必要分だけ用意する
fact = [1, 1]  # fact[n] = (n! mod p)
factinv = [1, 1]  # factinv[n] = ((n!)^(-1) mod p)
inv = [0, 1]  # factinv 計算用
 
for i in range(2, N + 1):
    fact.append((fact[-1] * i) % p)
    inv.append((-inv[p % i] * (p // i)) % p)
    factinv.append((factinv[-1] * inv[-1]) % p)

print(cmb(n, r, p))

cmb は二項係数 nCr_n \mathrm{C} _r を求める関数を指すとします。(パラメータは場合によって変わります)組合せを意味する Combination の略です。

やっていることをカンタンに説明すると次のようになります。

  • n!,(n!)1n!, (n!)^{-1} について、 pp で割ったあまりを配列にまとめておく(下準備)
  • nCr=(n!)×(r!)1×((nr)!)1_n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1} に基づいて計算する

下準備について、 n!n!pp で割ったあまりを fact に、 (n!)1(n!)^{-1}pp で割ったあまりを factinv にそれぞれ格納しています。

こうすることで、掛け算自体は O(1)O(1) でできるので、実質階乗を求めるのに必要な O(n)O(n) の計算量で処理をすることができます。

mod p における n! の計算

まずは fact[n] = n! を p で割ったあまり の計算について見ていきましょう。これは 0!=10! =1 であることに注意して、次のように作成することができそうです。

fact = [1]  # 0!

for i in range(1, N + 1):
    fact.append((fact[-1] * i) % p)

mod p における n! の逆元の計算

次に factinv[n] = mod p における n! の逆元 (n!)^(-1) の計算を見てみます。これは、次のように分解して考えることで、先程やった n!n! の計算と同じように処理できそうです。

(n!)1=i=1ni1(n!)^{-1} = \prod_{i = 1}^{n} i^{-1}

つまり、mod p における aa の逆元 a1a^{-1} を求めてやれば、それを掛け算していくことで mod p における n!n! の逆元 (n!)1(n!)^{-1} も求められるわけです。

モジュラ逆数

この「逆元」というやつは モジュラ逆数 と呼ばれるやつで、次のような数になるようです。

モジュラ逆数は、与えられた整数 aa と法 pp に関して
a1x(modp)a^{-1} \equiv x {\pmod {p}}

という関係にある整数 xx をいう。

具体例を考えて理解を深めましょう。

💡
ex. 整数3の法11に関するモジュラ逆数xを求める

つまり、 31x(mod11)3^{-1} \equiv x {\pmod {11}} なる xx を計算するということになるが、これは次式を満たす xx を計算することである。

3x1(mod11) 3 x \equiv 1 \pmod{11} 

これは、3×4=121(mod11)3 \times 4 = 12 \equiv 1 \pmod{11} から x=4x = 4 と求まる。(一般には x=4+11kx = 4 + 11 k の形をしている)

このようにモジュラ逆数という考え方を用いて n1n^{-1} にも pp で割ったあまりのようなものを定義してあげることで、結局 nCr=(n!)×(r!)1×((nr)!)1_n \mathrm{C} _r = (n!) \times (r!)^{-1} \times ((n-r)!)^{-1} の mod p における計算は、それぞれの mod p における値を掛け算した値になることがわかります。

モジュラ逆数が満たす性質

具体的にモジュラ逆数を計算するために、次のような性質を利用します。

sstt で割った商を s//ts // t 、余りを s%ts \% t と表記する。
このとき mod p における aa の逆元 a1a^{-1} について、次の関係が成り立つ。
a1(p%a)1×(p//a)(modp)a^{-1} \equiv -(p \% a)^{-1} \times (p // a) \pmod{p}

証明は次のようになります。

証明

ppaa で割ると p=(p//a)×a+(p%a)p = (p // a) \times a + (p \% a) が成立。両辺の mod p を取って、

(p//a)×a+(p%a)0(p//a)+(p%a)×a10(p%a)×a1(p//a)a1(p%a)1×(p//a)(p // a) \times a + (p \% a) \equiv 0 \\ \Leftrightarrow (p // a) + (p \% a) \times a^{-1} \equiv 0 \\ \Leftrightarrow (p \% a) \times a^{-1} \equiv -(p // a) \\ \Leftrightarrow a^{-1} \equiv -(p \% a)^{-1} \times (p // a)

と変形できる。(証明終)

mod p における n の逆元の計算

さて、準備が整ったので inv[n] = mod p における n の逆元 (n)^(-1) から計算します。これは上の式から次のようにできます。

inv = [0, 1]  # 便宜上 inv[0] = 0 とした

for i in range(2, N + 1):
    inv.append((-inv[p % i] * (p // i)) % p)

再び、mod p における n! の逆元の計算

本題に戻ります。factinv[n] = mod p における n! の逆元 (n!)^(-1) の話をしているんでしたね、ここまできたら inv を掛け合わせていくだけなので次のようにできます。

factinv = [1, 1]

for i in range(2, N + 1):
    factinv.append((factinv[-1] * inv[i]) % p)

改めて、結論

以上をふまえてもう一度コードを見てみます。余計なコメントは消してあります。

def cmb(n, r, p):
    if (r < 0) or (n < r):
        return 0
    r = min(r, n - r)
    return fact[n] * factinv[r] * factinv[n-r] % p

p = 10 ** 9 + 7
N = 10 ** 6  # N は必要分だけ用意する
fact = [1, 1]
factinv = [1, 1]
inv = [0, 1]
 
for i in range(2, N + 1):
    fact.append((fact[-1] * i) % p)
    inv.append((-inv[p % i] * (p // i)) % p)
    factinv.append((factinv[-1] * inv[-1]) % p)

print(cmb(n, r, p))