もし生物情報科学専攻の大学院生が "StableDiffusion" を理解しようとしたら 7 ~DDPM~
前回
cake-by-the-river.hatenablog.jp
今回は、拡散モデルの最も重要な論文である Denoising Diffusion Probablistic Models (DDPM)を解説します。
前回の潜在変数モデルとしての側面(AutoEncoder)も持ちつつ、スコアベースモデルによる画像生成であるNCSNと数学的に等価で、より学習の効率が良いアルゴリズムであるDDPMは、Stable Diffusion(Latent Diffusion Model)の中枢を担っているため、ここが理解できればStable Diffusionをはじめとした拡散モデルの概略を抑えることが出来ると言えます。今回も、最近発売された拡散モデルの本も参考にしています。なお、今回は数式が(大量にあった前回をさらに超えるほど)沢山出ますが、出来るだけ"お気持ち"を重視して解説するよう頑張ります。
DDPM
スコアベースモデルであるNCSNでは、異なる幅のガウスノイズを加えた分布上でのスコアを求めることで、複雑な分布の学習とサンプリングを可能としていました。一方、拡散モデル(Diffusion Probabilistic Models)では徐々に入力に対しノイズが強くなるような過程を連続して行うようにします。入力データに徐々にノイズの影響が乗ってくるため、最終的にはただのガウスノイズと見分けがつかなくなるまで変形していきます。
拡散モデルの元ネタは、非平衡統計力学という物理学の分野での話題にあります。拡散モデルと非平衡統計力学との関係について、詳しくは末尾に追記しようと思いますが、導入で軽く触れようと思います。
マルコフ過程による拡散と逆拡散
拡散モデルを考えるために、NCSNでの複雑な分布への適応の過程について、異なる視点から考えることにします。NCSNでは、最初は分散の非常に大きなノイズで近似された分布でのサンプリングを行い、少しずつノイズを抑えていき、最終的にノイズがほぼ0の分布に対するサンプリングに帰着させていました。各段階でのサンプリングはその段階での分布を満遍なく動く粒子のように扱っており、最初は単なるガウス分布の上に動く粒子だったものが、最終的に(画像の為す)複雑な分布の上を動く粒子へと時間で変遷したとも考えられます。
したがって、NCSNは単純なガウス分布から複雑な分布への変遷を考えていると言えます。このような分布間の変遷を考えることは、非平衡統計力学と呼ばれる分野でよく研究されてきました。非平衡統計力学では、熱力学における平衡状態から非平衡状態へと変遷する過程、あるいはその逆の過程を、確率過程に基づいて考えます。例えばペットボトルのお茶をシャバシャバと振った後、机の上に静止させておいた時のお茶の状態の変化を考える、といった場合、振った直後は非平衡の(乱流を含むような)複雑な状態ですが、最終的には風が吹かない湖のような安定した状態≒熱平衡状態へと変遷します。
ここでは、拡散過程と逆拡散過程というものを考えます。拡散過程は複雑な分布に従う粒子が単純な分布に従う粒子へと拡散により変遷する過程を、逆拡散過程はその反対の動きをする過程(時間反転に対応)を指します。
更にここでは、これらの過程に(単純)マルコフ過程と呼ばれる性質を仮定します。マルコフ過程は、次の時点での状態が現在の状態のみに依存し、過去の状態には依らないというもので、MCMCでも用いられています。したがって、離散化した時刻 に対し、拡散過程 や逆拡散過程 は以下のような形式で書けることになります。
画像生成では、適当な入力(ガウス分布など単純な分布に従う画像)から目的の画像を生成するため、これは逆拡散過程に対応し、学習では目的画像を出力する確率(尤度) を最大化する最尤推定が必要となります。今回は、逆拡散過程内に含まれる潜在変数 をすべて周辺化することで尤度が求まります。
とはいえ、画像自体が従う複雑な分布の情報は未知であるため、いくら単純な分布からスタートする逆拡散過程であっても、上記の積分計算は求めることが出来ません。しかし、今考えている逆拡散過程が拡散過程の時間反転に対応し、拡散過程の初期値 は観測変数で既知であることを踏まえると、次の式変形により拡散過程に関する積分へと書き換えることが出来ます。
3行目ではマルコフ過程に基づく展開を行っています。最終的に各時刻での二方向の遷移の確率の比の計算が現れるところがポイントとなります。なぜこの式変形を行おうと思ったのか、そもそもなぜこのようなモデルを考えることになったのか、については非平衡統計力学との関係に関する追記をご覧ください。
拡散モデルの変分推論
さて、対数尤度を用いて最尤推定を行おうと思います。上の式変形では入力の分布の情報を入れていなかった(入力を固定したときの値)ので、実際に最尤推定で用いる目的関数は入力に関する周辺化が必要となります(通常の学習ではサンプルのアンサンブル平均などで近似している)。
結局、logの中の積分項を実際に計算することが難しいという問題にすぐ直面してしまいました。しかし、私たちはこの状況を解決する方法をすでに前回学んでいます。拡散モデルでは拡散過程と逆拡散過程の二つのプロセスを用いていますが、これらはまとめて潜在変数モデルとみなせ、拡散過程はEncoder, 逆拡散過程はDecoderに対応するAutoEncoderモデルの形式をしていると言えます。となると、対数尤度に関する問題は変分推論法を用いて解けるはずです。つまり、ELBOの最大化問題に置き換えることが出来ます。
このままでもよい気がしますが、分数になっている部分の計算を から の形式へと変形することを考えます。これは、実際に学習する際のことを考えた処理になります。この指針を取る理由は後ほど明らかとなります。
実はあともう一押し必要です。それは、あえて を にすることです。なぜならば、後者は初期値が決まった(条件つき)拡散過程であるのに対し、前者はありとあらゆる拡散過程に関する平均、つまり初期値による周辺化がないと求められない確率だからです。
一方、後者に関しては、解析的に解ける可能性が高い計算になっています。とくに、元の分布に同じ形式のノイズを加えた後の分布も解析的に求まる場合(ガウス分布など。この性質は分布の再生性などと呼ばれている)について、について繰り返しノイズを加える計算を行ったあとの も、解析的な形で求められます。
このことから、条件付き拡散過程における事後分布 も、ベイズの定理を用いて次のように解析的に求められます。
なお、このあたりの計算については、後ほどガウス分布の場合に具体的に求めることとなります。
さて、ELBO を条件付き拡散過程 の形式に書き換えましょう。今回は で条件づけているため、その部分の計算はややこしくなるため切り離します(edge effect)。
第3項について、まず はマルコフ過程なので を条件に加えても何も確率が変化しないこと、およびベイズの定理から目標の となるように式変形します。
後半の畳み込み級数部分が大量に相殺しあうことで、シンプルになりました。これを元の目的関数に戻すと
さて、各項ごとに期待値計算を分離すると、各項内で登場する変数以外の積分は消すことが出来ます。
第1項は に関する二つの拡散過程での負のKLダイバージェンスとなりましたが、この部分は学習したいパラメータを含まず最適化するわけではないことから、無視できる項となります。第2項は逆拡散過程で入力データを再現するパートに対応しており、入力が画像の場合には各データの数値が整数となるようにする処理などが必要になります。
肝心の第3項について、各時刻 での拡散過程の事後確率と逆拡散過程(今回学習したい方)の間のKLダイバージェンスを下げる問題となっていることが分かります。ポイントは、各時刻 についてその一つ前の時刻の状態を推定する問題を解く形式となっていることです(正解の逆反応 をモデル で推定する)。これにより「全体の拡散を元に戻すプロセスを学習する」という大きな問題が「各時刻での拡散を元に戻すプロセスを学習する」という小さな問題へ分割することが出来ていると言えます。このような形式とするために、 の形だった目的関数をわざわざ逆さにしたというわけです。
こうして、基本的に各時刻で拡散により加えられたノイズのデノイジングを行うモデルの学習を、KLダイバージェンスに従って行うことで、全体の生成プロセスの学習ができることが分かりました。これが拡散モデルにおける学習の基本となります。ここではガウス分布をノイズとしたときのモデルと学習についてまとめます。
拡散過程 で加えるノイズの平均は一つ前の時刻の出力を少しずつ減衰させたものとし、分散はどんどん増やすものとします()。平均を減衰させているのが少し以外かもしれないですが、全くノイズがない状態から完全な標準ガウス分布への変遷を行っていることを考えると、滑らかにその間を繋げるためにこの形式をとる必要があることが想像できます。
これはガウス分布の再生性を帰納的に用いることで示せます(詳細は割愛)。これを用いて事後確率もガウス分布で表せることが示せます。
これらを用い、逆拡散過程のデノイジングに対応するガウス分布の平均・分散をこれらに沿うように推定します。ただし、分散自体は か をそのまま利用してしまう(推定する必要がそこまでない)ことが多いようです。したがって、以降での分散は という既知の値を使うこととします。
なお、逆拡散過程もガウス分布を用いることが出来る正当性については何も触れてきませんでした。実はこのあたりの議論も、非平衡統計力学の研究に基づいて正当化されており、十分な時間の極限を考えることで拡散過程と逆拡散過程で用いるノイズの形式(確率過程)が一致することが示されます。これについても、可能な限り末尾の追記にそのうち記載しようと思います。
DDPM
さて、目的のKLダイバージェンスのうち、予測したい平均 に関連する部分だけ取り出すことにしましょう。今、ガウス分布同士のKLダイバージェンスを考えていますが、ガウス分布のうち平均が現れるのは指数部分のみなので、残る正規化定数部分についてはココでは無視することが可能です。
KLダイバージェンスの式に代入すると
(最後の式変形は結構飛躍していますが、)このように、最小二乗法と同じ形式の損失関数となりました。実際、最小二乗法がガウスノイズのフィッティングを行っていると考えれば自然な結果だと言えます。
さて、拡散過程における平均 はどのような形状の関数なのでしょうか?予測する関数 は のみを入力とするため、 から を除いた形の関数を求める必要があります。ありがたいことに拡散過程の条件付き確率 はすでに計算してあり、デノイジングをこの式から逆算することが出来ます。
これを の式に代入してちょっと計算すると
と求まります。したがって、 はこの式におけるノイズ部分 を推定することが本質となります(デノイジング)。
最終的に、拡散過程において現時点までに加えられたノイズ を として予測するタスクだと考えられます。このように拡散モデルにおけるデノイジング過程を学習するようなアルゴリズムを、Denoising Diffusion Probablistic Model (DDPM) と呼びます。
さて、この形式の目的関数に見覚えはあるでしょうか?NCSNの目的関数を再掲します。
この二つは同じ形式をしており、定数倍の違いを除いて同じ処理で学習できることが分かります。つまり、この二つの手法は数学的には等価な枠組みのもとに語ることが可能です。
また、DDPMによるサンプリングは逆拡散過程を学習したデノイジングに分散 分のノイズを加えることで行いますが、これはNCSNにおける異なるノイズにおけるランジュバンの拡散過程の連続と同じ形式をとります。このように、DDPMとNCSNはその本質が近いものとなっています。
実際、NCSNは元のデータ分布をガウス分布で擾乱した後の分布について、そのデノイジングを行う枠組みであり、ノイズは標準ガウス分布に従う変数 を用いて
と書けました。一方、DDPMは
と書けます。この二つはどちらも元のデータ(シグナル) に、ノイズ が乗っているという形式ではありますが、それぞれにかかる定数倍の比率=シグナルノイズ比(SNR)が異なります。完全なガウス分布を作成するにはSNRが0となる必要があり、NCSNなどスコアベースモデルでは、(理論上)無限幅のノイズを用いる必要があるため、分散発散型のモデルと呼ばれます。一方、DDPMは最終的に分散を1に抑えているため、分散保存型モデルと呼ばれます。
これまでの議論からも察することが出来ますが、これらモデルの目的関数は、SNRの違いが定数倍に効くものの、本質的な形状は二乗誤差であり、学習方法は共有できることが計算から求められます。さらに、今考えているモデルはすべて時刻の離散的なものだったのですが、時間に対して連続極限をとることで、NCSNなどの分散発散型モデルとDDPMなどの分散保存型モデルとが一致することが示されています。すなわち、極限的には二つのモデルは数学的に等価であり、状況に応じて好きな方を用いてよいことになります。
このように、スコアベースモデルと等価でありつつ、画像入力に対する拡散過程と逆拡散過程をモデルとした拡散モデルのデノイジングとして定式化されたDDPMは、Stable Diffusionの根本メカニズムとして利用されます。次回からは、ついに Stable Diffusion の解説を徐々に始めることになると思います。Stable Diffusion に至るまでの流れや、各種メカニズム(分類類器なし条件付け、拡散回数の削減など)に関連する論文も紹介し、その仕組みを解説します。その後は、LoRAやControlNetといった追加機能に関する解説も出来る範囲で行う予定です。