もし生物情報科学専攻の学部生が "StableDiffusion" を理解しようとしたら 3 ~Transformer~
前回
cake-by-the-river.hatenablog.jp
今回は、今後頻繁に利用される自然言語処理のアーキテクチャである Transformer を紹介します。(Multi-Head) Attention からザックリですが解説を試みます。
Transformer
元論文:
arxiv.org
Transformer は Google が 2017年に発表したモデルで、それまで LSTM や GRN など RNN ベースのモデル(および Encoder-Decoder モデル)が主流だった自然言語処理の深層学習に対して、Attention メカニズムのみを利用した単純な構造であり、より多くの並列化も可能なため、上記のモデルに比べてGPUを効率的に利用でき、高速かつ高い精度を得ることが出来るものです。
以下の説明では、次に挙げる記事も参考にしました。
qiita.com
qiita.com
deeplearning.hatenablog.com
Attention
Attention メカニズム(注意機構)はもともと、自然言語処理における Encoder-Decoder モデルの持つ問題点を解決するために、Bahdanau et al., 2015 などで導入されました(それ以前にも Mnih et al., 2014 など他の分野でちょくちょくみられるようですが)。
(自然言語処理における)Encoder-Decoder モデルとは、機械翻訳など入力文字列から望みの出力文字列を生成する問題に対し、入力を受け取って文脈をベクトル化するRNN(Encoder)と、文脈を利用して出力を生成するRNN(Decoder)を連結したニューラルネットワークを用いたもので、Seq2Seq(Sutskever et al., 2014)などで導入され始めました。
このアプローチにおいて、Encoder への入力は(その仕様上)固定長の文脈ベクトルに圧縮されます。その結果、長い文章などを入力した場合に情報を上手く保持できず、精度が落ちることが知られていました。
Encoder は入力を逐次的に処理し再帰的に構築するため、中間状態(各時点でのRNN内の隠れ状態)が沢山できますが、利用するのは最終的にすべてを読み終わった後の文脈ベクトルのみです。Bahdanau らは、これらの中間状態を全て用いつつ、出力の各時点で重みを適切に変えた中間状態の「寄せ集め」を活用する方法を提案しました。すなわち、出力する文章の各単語ごとに、入力された文脈のうち一部のみに"注意"をむけて利用する、という仕組みを用いました。
このとき、最終的に出力に利用する文脈ベクトル は、Encoderの各時点での隠れ状態 "annotation" を重み により平均化した
となります。
問題はこの重み付けをどう決めるかですが、これは、入力のもつ(色々な時点での)文脈のうち、直前で出力した時点の文脈 にできるだけ近いものを取ってくるようにしたいはずです。これを、機械翻訳の例で具体的に考えてみます。ある英語の文章をフランス語の文章にしたいとして、ある程度翻訳した後、次に出力すべきフランス語の単語を予測します。ここで、直前に出力したフランス語の文章の文脈と類似した元の英語文章中の時点の状況に注意を向けることで、その英語の文脈の後に現れた単語と類似するフランス語の単語を高く評価すればよいことが分かります。そしてこのアプローチは、いったん英語の文章をすべて通読して全体の(ある程度均質化した)文脈から予測するよりも、確かによさそうだと言えます。
これは、Bahdanau らが論文中で "alignment" として定式化しています。(ここからはお気持ちです)アラインメントというとバイオインフォマティクスでもよく利用してきたテクニックです。ヒトの遺伝子配列とマウスの遺伝子配列をアラインメントする、すなわち対応する箇所を並べて揃える処理は、英語の文章とフランス語の文章を並べるのと同じで、各単語(遺伝子の場合は DNA塩基 or アミノ酸)の文脈が揃うような箇所を検索しています。バイオインフォマティクスの場合は、揃う箇所に高い数値(重み)を与えるモデルとして、統計力学的なスコア(≒エネルギー)関数のカノニカル分布(softmax)を考えますが、Bahdanau らの論文でも、同様に何かしらのスコア関数の softmax を計算するようにしています。
さて、問題は「どのように文脈同士のスコアを決めるか」の形になりました。Bahdanau らは フィードフォワードのニューラルネットワークを用いて求めています。しかし、別にそうする必要はなく、Luong et al., 2015 では内積を取るシンプルなアプローチ Dot-Product Attention が提案されました。
Scaled Dot-Product Attention, Multi-Head Attention
次に、Attentionが辞書の形式を取っていることを説明し、Transformerでの使われ方を見ていきます。
アラインメントとして定式化したAttentionですが、アラインメントはある文脈を別の文章中から"検索"することに対応していました。この構図は、データ構造としては辞書に対応します。つまり、とある query を文脈を保持した辞書内で検索し、(複数の) key に対応する(複数の) value を受け取るわけです。例えば、入力された英文の各時点の文脈は、まず今欲しいフランス語の文脈 query に検索されて、各英語の文脈 key ごとに関連度の重み(Attention)の形で出力されます。そして、それぞれの文脈の実際の値 value を重み付き平均することで、対応するフランス語の文脈を獲得します。
Transformer では RNN を利用せず、辞書という構図のみを抽出してくるため、今まで RNN の隠れ状態としてきた を query , key , value と呼ぶ行列で考えます(上の例では key と value は両方 に対応する)。RNNの文脈で、再帰性のため query は1時点のみとなりますが、そのような制約はなくなったので、複数の query を同時に入力できます。なお Transformer では、内積が高次元で大きくなりすぎる問題を解決するため、適度に key の次元 でスケールした Scaled Dot-Product Attention が利用されます。行列計算の数式にすると、以下の形で書けます。
さらに、これらの辞書は並列化して計算することが可能です。 Multi-Head Attention は、 を線形変換により少しずつ異なる複数の空間(head)に分け、各 head の Scaled Dot-Product Attention を計算したうえでそれを再び統合する形を取ります。
具体的には、それぞれの head に対し、Attentionを行列 を用いた線形変換(Linearレイヤー)を用いて
として求め、再び線形変換 によって統合します。
こうした分割はそこまで計算量を増やさずに複数の空間でのAttentionを求められるという利点があるようです。
また、Transformer では Encoder-Decoder モデルで用いられたのと同様の、 Decoder の出力を query (target), Encoder の出力を key, value (source) とするものだけでなく、query, key, value すべてを同じ文字列から生成する Self-Attention という仕組みも Encoder, Decoder パートで用いています。これは RNN などを置き換えるのに利用していますが、(上記の説明からわかるように)Attentionの計算は逐次的ではないため並列に計算すればかなり実行時間を削ることが出来ます。そして文章全体を毎回見ることで RNN や CNN で問題となる情報の流出もしにくく、精度が上がりやすくなります。また、Attentionの可視化によってモデルの解釈もしやすくなると言えます。
Transformer のアーキテクチャ
さて、実際の Transformer のアーキテクチャを見てみます。
いくらか紹介していない部分について軽く説明します。
まず、Positional Encoding ですが、これは Transformer が RNN をやめて Attention を用いることにしたため、導入されています。Attention では、計算の際に文中の順序の情報を用いない(辞書は本来順番のない集合を扱う)ため、語順を明示的に入力に加える必要があるというわけです。続いて、灰色で示された各ブロック内の FFN について、ここでは各単語ごとに計算されるものとなっています。基本的に Encoder-Decoder モデルの形状はそのまま、中身を Attention 満載のものに変えたという感じです。
正規化など
Transformer では、"Add&Norm" と称されるレイヤーがあります。これは、残差接続(Residual Connection)と呼ばれているもので、ResNet (He et al., 2016) などで有名になりましたが、前回の FCN(U-Net) でも登場した スキップ接続 とほとんど近いものとなっています。"残差"と呼ぶのはスキップした層で学習する内容が、元の入力を"修正"するものとなる形状をしているからです。つまり、迂回前の入力 が迂回された層(今回は Multi-Head Attention など)の結果 と足しあわされた出力 に対し、 という差分の形で書けるということです。
残差接続は、誤差逆伝播のときの勾配消失に対してロバストだと言えます。なぜならば、迂回された層でたまたま勾配が 0 となってしまっても、迂回ルートの方は勾配が残るために、より上流まで勾配が流れやすくなるからです。制約が緩いために仮に学習に失敗しても部分的にしか影響しないという風にも捉えられます。
さらに、Transformer では至る所にレイヤー正規化も用いられています。これは、バッチ正規化など一般的に複数の入力文字列データで正規化するのではなく、各文字列データごとにその中の単語全体で正規化します。これに Dropout も加え、過学習を抑える工夫がなされています。
こうしてTransformerは機械翻訳に限らない様々なタスクに汎用可能なモデルとして利用されるようになります。
コードの紹介
Transformerに関しては疑似コードを載せるにはかなり大きいため、PyTorchの実装が為された文献を紹介するにとどめておきます。決して面倒だったからではないです。
次回は、前回までの画像処理と今回の自然言語処理の両者が融合する流れとなった ViT と CLIP を紹介しようと思います。