もし生物情報科学専攻の大学院生が "StableDiffusion" を理解しようとしたら 5 ~NCSN~

前回

cake-by-the-river.hatenablog.jp


お久しぶりです。今回から肝心の拡散モデルを見ていこうと思います。今回は、スコアベースモデルを用いた Noise-Conditional Score Networks (NCSN) を解説し、拡散モデルの本題であり次回扱う Denoising Diffusion Probabilistic Models (DDPM) の前座となる要素の解説をしようと思います。


なお、今回以降の内容は、最近発売された拡散モデルに関する入門書の内容を参考にしました。

スコアベースモデル

そもそも画像を生成するAIというのはどのように設計すればよいのでしょうか?StableDiffusionなどが流行る前は、GANと呼ばれるモデルが注目されており、これらは生成モデルと呼ばれていました。

生成モデルとは、与えられたデータが為す空間を学習し、与えられたデータと類似したデータを生成するAIモデルです。画像の例を基に考えます。

デジタルの画像とは、それぞれのピクセルがRGB(A)のチャネルごとに数値をもった大きな数値の塊でした。例えば各ピクセルの値を一列に並べたベクトルを考えることにすると、縦横dピクセルの3チャネルの画像は大きさ d×d×3 のベクトルであり、各要素は 0 - 255 の 256 種類の整数値を取り得るものだと考えられます。これら数値は画像ごとに異なり、それぞれを d×d×3 の巨大な次元を持つ空間上の一点に対応させることが出来ます。

機械学習の考え方では、この巨大空間上の各点(すなわち各画像)に確率の数値を振り分けます。この確率は、私たちがAIに学習させたい画像の特徴を捉えたものとします。例えば、青空を生成してほしければ、青のチャネルに対応する数値が高く、赤や緑の数値が低い画像(点)が対応するため、その領域の確率は高いものとなるはずです。より高度には、人間の形をした画像の特徴や、イラストのように周囲の色が同一になりやすい画像の特徴などに合わせ、何かしらの形式で確率が決まります。

この確率の"山々"に応じて画像を生成することにします。そうすることで、私たちが思い描いた画像を高確率で生成できるわけです。確率を用いているのは、決定的な画像の生成では多様性がなくなってしまうから、とも捉えられます。

私たちが最終的に目指したいのは、例えば「青空の下咲き乱れるひまわりの畑の写真」といった条件をつけた画像を生成することですが、しばらくは、条件のない画像の単なる生成について考えることとします。

スコアに着目する

さて、画像の為す空間上の確率分布を上手く推定さえできれば、確率に従って生成すればよいですが、この確率を単にそのまま推定することは、非常に難易度が高いと言えます。私たちは確率分布自体の形状は知らないため、思い描いている画像の集合(学習データ)をAIに与えることしかできません。その場合、画像の空間上のいくらかの点を与えたことになり、その密度(どの辺に点が集まっているか)によって確率を疑似的に学ぶことになります。このような場合、機械学習で使われてきた代表的な方法は最尤推定です。よくある例としては、データ \mathbf{x} のなす確率分布 q_{\theta}(\mathbf{x}) がエネルギー関数によって表現された(統計力学の)カノニカル分布の形式で表せるとし、学習データからこのエネルギー関数  f_\theta(\mathbf{x}) を推定するものです。

 \displaystyle q_{\theta}(\mathbf{x}) = \mathrm{exp}(-f_\theta(\mathbf{x})) / Z(\theta)

最尤推定ではデータの尤度を最大化しますが、上の関数の尤度は、

 \displaystyle L(\theta) = -\frac{1}{N}\sum_{i=1}^N [f_\theta(\mathbf{x}^{(i)})] - \mathrm{log} Z(\theta)

となります。ここで、第二項の分配関数  Zすべての画像の確率を必要とする数値であるため、(尤度を最大化するために必要な)尤度の勾配を求めることは困難です。仮に第二項を考慮しないとすると、学習していないデータの領域の数値も不用意に下がってしまい、予期せぬ画像を生成する可能性が高くなるため、やはりダメです。

一方、分配関数を計算せずとも確率分布のサンプリングを可能にする方法に、MCMCマルコフ連鎖モンテカルロ法)があります。詳細はこの記事の追記に書きますが、MCMCでは、ブラウン運動のように空間上を右往左往する際、その動きを実際に行うかどうか、現在地と予定地の尤度の比で判定(採択)します。先ほどの  q_\theta の式を考えると、分配関数Zの項は打ち消しあうため、エネルギー関数のみで判定でき、得られたランダム性のある動き(確率過程)から分布を予測することが出来ます。


しかし、MCMCは局所的な山にとらわれやすく、また山の間を行き来する動きはエネルギーが高く採択されにくいため、効率的に確率分布の全体の形状を知ることが出来ません。そこで、今解きたい問題を少し変形することでこの問題の解決を図ります。


確率分布が滑らか(微分可能)であるという仮定(カノニカル分布ではエネルギー関数が滑らかという仮定)を置くと、対数尤度  \mathrm{log}\ p(\mathbf{x}) の入力  \mathbf{x} に対する勾配が計算できます。この値は スコア と呼ばれています。

 \displaystyle \mathbf{s}(\mathbf{x}) = \nabla_{\mathbf{x}} \mathrm{log}\ p(\mathbf{x}) = \frac{\nabla_{\mathbf{x}} p(\mathbf{x})}{p(\mathbf{x})}

対数の微分を考えると、スコアは確率分布の山の傾きをその地点の高さで割った値であるため、山の頂上に近い部分では十分小さくなります。また、カノニカル分布でのスコアは、分配関数  Z が入力  \mathbf{x} に依存しないため、分配関数の勾配の項が 0 となり、エネルギー関数の勾配に等しくなります。

 \displaystyle \nabla_{\mathbf{x}} q_{\theta}(\mathbf{x}) = - \nabla_{\mathbf{x}} f_\theta (\mathbf{x}) - - \nabla_{\mathbf{x}} \mathrm{log}\ Z(\theta) = - \nabla_{\mathbf{x}} f_\theta (\mathbf{x})


このスコアを使ってもとの確率分布を効率的にサンプリングするには、ランジュバン・モンテカルロ法を用いることができます。この方法は、MCMCの一種ではありますが、スコア、すなわちエネルギーの勾配が下がる方向を基本として、さらにノイズも乗せた動きを行います。すなわち、毎回の遷移が以下の式で表されます。

 \mathbf{x}_k = \mathbf{x}_{k-1} + \alpha \nabla_{\mathbf{x}} \mathrm{log}\ p(\mathbf{x}_{k-1}) + \sqrt{2 \alpha} \mathbf{u}_k, \ \ \mathbf{u}_k \sim \mathcal{N}(0, \mathbf{I})

なおこの式は、毎回の微小な動きが、エネルギーの下がる方向へのドリフト(第1項)にブラウン運動(ウィーナー過程, 第2項)を追加した確率過程に対応するものと言えます。

 d\mathbf{X}_t = -\nabla_{\mathbf{x}} E(\mathbf{X}_t) dt + \sqrt{2} d\mathbf{W}_t

このアルゴリズムは、通常のMCMCとは異なり動きを必ず毎回行う(採択率100%)ため効率的で、また局所的な構造にとらわれにくいという特徴があります。なぜこのアルゴリズムで(極限的に)元の確率分布からのサンプリングが可能なのか、なぜ効率が良いかなどについては、記事の最後に追記した章にて解説しています。

デノイジングスコアマッチング

このように確率分布ではなくスコアを学習し、それを利用して生成を行うモデルはスコアベースモデルと呼ばれています。私たちの次の問題は、どのようにスコアを推定するか、です。


一番に思いつくのは、スコア  \nabla_{\mathbf{x}} \mathrm{log}\ p(\mathbf{x}) とそれを学習するモデル  \mathbf{s}_\theta (\mathbf{x}) の差をL2ノルムで抑えるような目的関数を導入すること(明示的スコアマッチング, ESM)です。しかし、学習データは画像(空間上の点)の集合でしかなく、確率分布のスコアは未知なので、実際に明示的に学習することはできません。

 \displaystyle J_{ESM_p}(\theta) = \frac{1}{2} \mathbb{E}_{p(\mathbf{x})} [ \| \nabla_{\mathbf{x}} \mathrm{log}\ p(\mathbf{x}) - \mathbf{s}_\theta(\mathbf{x}) \|^2 ]

これを解決する方法は色々と提案されていますが、大事なポイントは、この明示的スコアマッチングと数学的に等価で訓練データのみからでも学習可能なアルゴリズムを考えることです。ここでは、NCSNの基本となっているデノイジングスコアマッチングという手法を紹介します。


現在、確率分布自体は未知で、そこから得られたサンプルだけが利用できる状況ですが、このような場合に元の確率分布を(曲がりになりにも)推測する方法として、カーネル密度推定と呼ばれる手法があります。また、より身近なものとしては、ヒストグラムが挙げられます。例えば、ある集団の身長の数値をたくさん取得したとして、なんとなくそれら身長の分布の形状を知りたい場合、まず私たちはヒストグラムを可視化すると思います。カーネル密度推定は、雑に言えば、このヒストグラムをより滑らかに(それっぽく)得られる手法です。

ヒストグラムは、データが各区間に何個存在するかをカウントし、その数を高さに反映させた長方形の棒で分布の形を近似したものですが、何も長方形の棒で近似する必要はないはずです。たとえばガウス分布のような滑らかな分布で近似すると、それらを組み合わせた全体の分布も滑らかで、よりそれっぽい見た目になると言えます。具体的には、各データ点の値を平均  \mu とし、特定の分散  \sigma(分布の横幅に対応)をもつガウス分布(下図緑線)を単に足し合わせます。下の図のように、ヒストグラムと比べてよりそれっぽいものが出来上がることが分かります。

カーネル密度推定

ここで、カーネル密度推定を用い元の確率分布をある程度似せた分布について、その傾きやスコアは実は解析的に求めることが可能であることに注意します。なぜならば、この分布が各ガウス分布の足し合わせに過ぎず、それらガウス分布上の任意の点でのスコアは以下のように求まるからです。

 \displaystyle \nabla_{\mathbf{x}} \mathrm{log}\ \frac{1}{\sqrt{2\pi\sigma^2}} \mathrm{exp} (-\frac{1}{2\sigma^2}\|x - \mu\|^2) = 
 \nabla_{\mathbf{x}} (-\frac{1}{2\sigma^2}\|x - \mu\|^2) = -\frac{1}{\sigma^2}(x - \mu)

一つ目の等号では定数倍が無視できることを用いています。こうして、ガウス分布によるカーネル密度推定をした分布でのスコアは、訓練データのみから得ることが可能であることが分かります。


さて、肝心の分布のスコアの学習方法については、未だ曖昧なままとなっています。例えば、学習自体は確率分布をそのまま扱うことは出来ず、何かしら離散的な数値計算に基づいて行う必要があります。

現在利用できそうなのは、各データ点を平均としたガウス分布でのスコアの値のみであり、その値がそのガウス分布の分散と平均からの差(高次元ならばベクトル)で与えられています。このことを逆手に取ります。すなわち、元の訓練データ点  \mathbf{x} に分散  \sigmaガウスノイズを加えた点  \mathbf{\tilde{x}} を新たに生成します。この生成過程は

 \mathbf{\tilde{x}} = \mathbf{x} + \epsilon,\ \ \epsilon \sim \mathcal{N}(\mathbf{0}, \sigma^2 \mathbf{I})

として表現できます。このとき、ノイズの乗ったデータ  \mathbf{\tilde{x}} からノイズを除去(デノイジング)し、元のデータ  \mathbf{x} を求める問題を考えます。これは加えられたノイズ  \epsilon を予測する問題であり、分散の値で正規化してあげることで、以下の形式の目的関数を最小化する問題に帰着します。

 \displaystyle J_{DSM_{p_\sigma}}(\theta) = \frac{1}{2} \mathbb{E}_{\epsilon \sim \mathcal{N}(\mathbf{0}, \sigma^2 \mathbf{I}),\mathbf{x} \sim p(\mathbf{x})} [ \| - \frac{1}{\sigma^2}\epsilon - \mathbf{s}_\theta(\mathbf{x} + \epsilon, \sigma) \|^2 ]

モデルとなる関数  \mathbf{s}_\theta(\mathbf{x} + \epsilon, \sigma) は(あえて) \sigma も引数としておくことで、あらゆる分散のノイズを加えた場合でも推定できるような一般性を獲得させています。そしてこの目的関数は、実は明示的スコアマッチングの目的関数と定数倍の差しかとらないことが示せます。

 \displaystyle J_{ESM_p}(\theta) = J_{DSM_{p_\sigma}}(\theta) + C

詳しい証明は、参考に上げた本や、以下の記事を参考にしてもらいたいです。お気持ちとしては、 J_{DSM_{p_\sigma}}(\theta) でフィッティングしたい値は、議論を巻き戻せばカーネル密度推定の分布での(1つのガウス分布での)スコアに対応しており、全訓練データの場合の重み付き平均を考えることで、近似した分布でのスコアの推定と等価になるといった流れになると思います。

www.beam2d.net


デノイジングスコアマッチングでは、摂動を加えた後の分布に関する推定を行うことになるため、元の分布の正確な推定にはなりません。しかし、実際に学習する際には訓練データのみを扱うため、元の分布を推定しようと思っても経験分布、すなわちデルタ関数の混合分布として推定するのが一番となってしまい、いわゆる過学習が起こります。摂動、すなわちガウス分布により”滑らかにする”ことは、こうした過学習を抑制することにもつながります。

NCSN

ここまでの内容から、デノイジングスコアマッチングにより推定されたスコア関数を用い、ランジュバン・モンテカルロ法を用いて画像を生成することで問題が解決するように見えます。しかし、この方法でもいくつかの問題点が残ります。これらの問題点を解決するように設計されたのが、Noise-Conditional Score Networks(NCSN)です。

arxiv.org


まず、データ点があまり存在しない部分のスコア関数の推定が不正確となる問題があります。これは、世の中の画像データが、高次元の空間上の、ずっと低次元の部分空間(多様体)上にしか存在していないという多様体仮説と合わさって深刻な問題となります。最初の例で言えば、d×d×3という高次元の空間のすべてが画像となりうるものの、実際に写真やイラストなどの"自然な"画像は、ずっと少ないパラメータ(次元)で記述することが出来るというものです。

このように空間の多くの点で確率がほとんど0となっている場合、その場所のスコアは未学習となります。しかし、実際に画像を生成する際の初期状態は、例えば(テレビの砂嵐のような)ランダムに色を充てたものから始めることが多く、その多くでスコアが不正確であるため、肝心の"自然な"画像のなす低次元の多様体上に遷移することが困難となります。

確率がほぼ0の値を取る領域でのスコアの推定が不正確となる例。左が実際の分布とそのスコア(矢印からなるベクトル場)で、右が学習により推定されたもの。論文より引用。

さらに、ランジュバン・モンテカルロ法は一般のMCMCと比べて山を越えやすくはなったものの、多峰性の確率分布の各モード(山)間の行き来はやはりレアであり、どうしても多くのステップを必要とします。このことで、短いステップ数では各山に過剰に滞在し実際の密度とは異なるサンプリングをもたらすことが考えられます。さらに、確率の山の間を形成する確率が低い領域ではスコアも不正確となっているため、そもそもあらぬ方向へと動いてしまい、異なる山への移動が出来ないことも考えられます。

こうした問題に対し、NCSNではデノイジングスコアマッチングのモデルスコア関数がノイズの分散  \sigma を含むことを利用します。上記の問題は、確率分布が空間の多くの点で0となることや、多峰性が原因でした。一方、デノイジングスコアマッチングでは元の分布に摂動を加えた(カーネル密度推定的)分布の推定を行います。このとき、各ガウス分布の分散  \sigma を非常に大きくとれば、裾が非常に広くなることで空間のほとんどで分布の値は非ゼロとなり、また分布の形状も単峰性に変わりやすくなります。このような分布であれば、ランジュバン・モンテカルロ法でも問題なくサンプリングが可能です。

σを大きくすることで分布が単峰になる様子

NCSNでは、ノイズのレベル  \sigma を複数用意します。最初は大きな  \sigma でサンプリングを行い、徐々に小さな  \sigma でサンプリングするようにすることで、多峰性の分布でも上手に探索することが出来るようになります。例えば、最初は元の確率分布が分からなくなり全体が大きなガウス分布に見えるほどの  \sigma にてランジュバンの拡散過程を実施、数千ステップかけて小刻みに  \sigma を小さくしつつ拡散を実施したのち、最終的に  \sigma \sim 0 にて最後のサンプルを取得することで、元の確率分布を高精度に模倣したサンプリングが可能となります。

Annealed Langevin Dynamics により正しくサンプリングが可能となることを示す図。左が元の分布に基づくサンプル。真ん中は純粋なランジュバン・モンテカルロ法。右が今回の手法。純粋な方法では山の密度を誤って推定しているが、今回の手法では正しく推定できている様子が分かる。論文より引用。


今回は、拡散モデルの基本となるランジュバン・モンテカルロ法やデノイジングという概念によるスコアベースの分布推定について扱いました。次回は NCSN と数学的に等価でありつつ、Stable Diffusion の根本原理に通ずる DDPM を解説したいと思います。


cake-by-the-river.hatenablog.jp


ランジュバン・モンテカルロ法の解説

随時書き進めます。

メトロポリス・ヘイスティング法
ランジュバン・モンテカルロ法