SE(3)不変とタンパク質立体構造予測④:SE(3)同変な深層学習

前回、SO(3)同変な関数を考えるための数学的基礎として、球面調和関数を用いた関数の展開とClebsch-Gordan係数による次元の変換を扱いました。

cake-by-the-river.hatenablog.jp

今回は、これらを陽に用いた機械学習フレームワークとして、SE(3)-Transformers の理解を目指します。

Tensor Field Networks

Tensor Field Networks は Google が2018年に発表した、3次元空間を入力としてSE(3)同変性を保つニューラルネットワークです。

arxiv.org

前回までに抑えた基礎をCNNの構造と類似させることで活用しています。

pointwise convolution

CNNでは、周辺画素の情報をフィルター(カーネル W_{ij} の畳み込みによって計算します。例えば3×3のフィルターを用いた場合、その点の画素と上下左右斜めの8画素を重み行列(カーネル) に基づいて適切に足し合わせて出力とします。

 \displaystyle \sum_{i,j} W_{ij} x_{i,j}

一方、Cα原子の集合など、3次元空間上の点群を入力としたニューラルネットワークを考える場合、必ずしもそれらの点が画像のような格子点上に配置されているわけではなく、それぞれの点の間の距離は連続値を取ります。このような連続した距離をもつ隣接関係の中でもCNNが持つ「畳み込み」のような演算を行うことを考えたいですが、そのままではイマイチつかみどころがない気がします。

CNNの畳み込みを連続的な点群に応用するには?


ここで、ある発想の転換をします。ある範囲の画素を入力とするフィルターを考えましたが、この時の隣接関係を実際に連続的な距離に基づく関係に置換しても等価になるようにしてみます。例えば、画素間の基本距離単位が1であるとして、3×3のカーネルに相当するものとして、中心からの距離  r \sqrt{2} 以下の点は 1、そうでない点は 0 となるような関数を用いてみます。すると、確かに周囲9画素だけが計算範囲となるように設定できます。この関数を用いて各画素の特徴量を計算することは、確かに3×3の範囲の計算を行うのと等価と言えます。

同じように3次元空間上の点群に対し、中心からの距離(動径)を関数としたフィルターのようなものを考え、そのフィルターを基に他の点との関係値(重み)を求め、その重みを基に特徴量の畳み込み演算を行うことは、2次元CNNにおける畳み込みと同じ類の計算ではないでしょうか?


「いや、動径関数だけを使ったフィルターだと等方的にしかならなくてCNNみたいに3×3の各画素で異なる重みを付けることはできないのでは?」というツッコミが出てきそうです。実際、動径関数のみを使う場合はそうなってしまいます。しかし、私たちはすでに異なる方向で異なる値を取る関数をコンピュータ上で扱う方法を知っています。球面調和関数です。球面調和関数とはもとより球面上での異方性をベクトルで表現できる関数として導入しました(電子軌道を思い出してください)。そのため、ある程度の複雑さの異方性であれば、対応する低次元の既約表現を基にその異方的な特徴を抽出することは可能です。


これが、Tensor Field Networksのキモである pointwise convolution の基本的な発想です。

CNNのカーネルは周辺画素の重み付き和の形で定義しましたが、TFNでは各点で共通する動径関数と球面調和関数の積をフィルターとして扱い、畳み込み演算をします。実際の計算は点ごとに行うため、pointwiseと呼ばれるわけです。フィルタの形も非常に単純で、動径関数  R と球面調和関数  Y_m^{(l)} を用いて

 F(r,\theta, \phi) = R(r) Y_m^d({(l)}(\theta, \phi)

と書くだけです。動径関数  R(r) を学習のメインとして据えることで、CNNのフィルターの学習のように使うことができます。TFNでは動径関数を複数のガウス関数の和として表して利用しています。

SO(3)同変な畳み込み

このようにTFNにおける畳み込みの基本を導入しましたが、これらの計算がSO(3)同変であるように拡張していく必要があります。

フィルター自体は球面調和関数の線形和でしかないため当然SO(3)同変であり、それらの畳み込み演算も重み付き和だから同変なのでは?と思うかもしれないですが、CNNの畳み込み演算を考えると、フィルターの適用によって出力される表現ベクトルの次元が変化する可能性がありました。例えば、U-NetのようなCNNを拡張したニューラルネットワークにおいては、入力ベクトルの次元を拡張するdeconvolutionという演算が使われています。今回の場合、 l=0 の表現に対応する入力ベクトル(すなわち、スカラー)に対して  l=1 のフィルターをかませた場合、出力される表現空間は  l=0 のみならず、 l=1 にも対応した 4 次元のベクトルになりそうです。


このような懸念も、前回導入したClebsch-Gordan係数を用いればすぐに解決できます。CG係数は異なる次元の既約表現に対応した表現空間の間を結びつけることができるため、それを畳み込み演算に組み込むことで解決するわけです。


こうして pointwise convolution レイヤーを導入できます。入力である特徴量は各軌道量子数  l に対応したものをまとめたものとして用意され、それぞれが(点群の数  a, チャネル数  c, 既約表現の数  m)のテンソルからなります。

 input = \{ V^{(l)}_{a,c,m} \ \ ( l = 0, 1, ... ) \}

もちろん l の上限は問題ごとに適切に決めます。レイヤー内の計算は

 \displaystyle L_{a,c,m_o}^{(l_o)} (r_a, \theta_a, \phi_a, V_{a,c,m_i}^{(l_i)}) = \sum_{m_f, m_i} Q_{(l_f, m_f), (l_i, m_i)}^{(l_o, m_o)} \sum_b F(r_{ab},\theta_{ab}, \phi_{ab}) V_{b,c,m_i}^{(l_i)}

となります。ここではCG係数  Q_{lk} に含まれる変数を明示的にすべて書いています。大変長くキモく見えますが、行っている計算はあくまで「相対ベクトルを使ったフィルタ  F と 他の点の特徴量  V の重み付き和」を「CG係数で適切な次元に振り分ける」程度のことです。


ところで、CNNでは画素間の関係性を計算するフィルタのみならず、チャネル間の関係性を計算することもありました。そのような計算は 1x1 convolution と呼ばれていますが、TFNでも同様の計算を行うことができます。イメージとしては同じ点内に含まれている現在の特徴量同士の関係性を学習しようとしているため、self-interaction レイヤーと呼んでいます。実際の計算は、各チャネル同士の重みつき和からなるため、特に難しくはないです。


これらのレイヤーに、concatenationレイヤ-(CG係数を基にバラバラに出力される特徴量を同じ次元でまとめるレイヤ-)や nonlinearレイヤ-(非線形関数をかませるレイヤ-)を組み合わせてTFNを作ることができます。

例えばテトリスのような形状を分類する問題では、

  • 入力は各ブロックの位置のみ(スカラー l=0 の情報)
  • 分類に必要な情報は、他のブロックが相対的にどこに存在しているかで十分(ベクトル➡ l=1 の情報)
  • 出力はどの形状かを当てるだけ(スカラー l=0 の情報)

となるため、内部で  l=1 にdeconvolutionしつつ l=0 に戻す形状のニューラルネットワークで十分です。実際、そのようなおもちゃモデルが公開されています。他にも論文中では、ニュートン力学を学習したり、分子の生成を学習させています。

テトリス風のブロック分類問題を解くTFN。論文より引用。


ここまで、作ったレイヤーが本当にSO(3)同変かどうかを計算的に確かめることをしませんでしたが、論文では実際に数式と図を用いて証明しています(記事末尾でも扱う予定です)。

SE(3)-Transformers

arxiv.org

実は、TFNを理解できた時点で、SE(3)-Transformersを理解するためのポイントは8割程度抑えてしまっています。

  • 各点同士は動径関数と球面調和関数に分けてフィルターを作り、動径関数を学習する
  • CG係数を使って異なる軌道量子数の表現空間同士は接続する

このうち、前者のポイントは実は Graph Neural Network (GNN) の導入そのものに対応していたりします。つまり、各点どうしの距離を用いて点の間の情報を交換するやり方が、そっくりそのままGNNというわけです。


TFNでは各点での重みは動径関数と球面調和関数、およびCG係数のセットから構成しました。このとき、すべての点でこの関係は同じものとしてみなしていましたが、実際には隣り合う点の関係性をもっと直接的に組み込んだ重みを導入したいと考えることがあります。例えば、点群に文脈のようなものがあるとみなすならば、その文脈を反映した重みづけを採用したいです。

文脈を適切に取得するアーキテクチャとしてもっとも有名なものは、Attention機構です。つまり、各点同士で key-value に基づくマッチ度を求め、それに応じて attention 値を設定し、それを重みとした計算をすることができます。これが SE(3)-Transformers です。想像よりはシンプルな作りです。

SE(3)-Transformersの解説。論文より引用。

まずは近傍の点の特徴量を(閾値等で)取得し、それぞれの点との相対距離と角度を用いて、動径関数・球面調和関数・CG係数から重みを求めます。今回は、この重みをAttentionにおける key, value, query の行列として用います。これらのレイヤーも、本質的には球面調和関数とSO(3)の表現論、および相対ベクトルを用いているため、SE(3)の同変性を獲得することができています。これまた論文に証明もあるので、興味がある方は是非読んでみてください。

RoseTTAFold

RoseTTAFoldでは、特徴ベクトルをMSAを用いたTransformerによって学習していました。これらの特徴は、各アミノ酸の点の情報とそれぞれのペアの情報の二つからなります。RoseTTAFoldでは、アミノ酸におけるN, Cα, C原子の位置情報(初期値あるいはそれまでの計算で得られているものとします)と二つの情報を基に、近傍グラフを作成し、それをSE(3)-Transformersにかけることで情報の更新を行うことにしています。特に、SE(3)-Transformersでは  l=1 の情報としてベクトルの情報を得ることができるため、これを基に新しい位置を推定しなおします。RoseTTAFoldで使われるSE(3)-Transformers内部でも、 l=0, 1 の計算が利用されています。

RoseTTAFoldにおける特徴ベクトルとSE(3)-Transformerの融合。論文より引用。


これがRoseTTAFoldにおける立体構造推定モジュールの主要な部分です。実際には、5回このプロセスを繰り返して修正した結果が用いられています。


ここまで、RoseTTAFoldを代表とする立体構造予測ツールにおけるSE(3)不変なニューラルネットワークを扱いました。次回は、Invariant Point Attention と呼ばれる別の仕組みを用いたAlphaFold2の立体構造予測モジュール、および今後の進展を合わせた最終回とします。

補足

TFNレイヤーのSO(3)同変性の証明

いずれ書きます。