もし生物情報科学専攻の学部生が "StableDiffusion" を理解しようとしたら 2 ~U-Net~

前回

cake-by-the-river.hatenablog.jp


今回は、細胞など医用画像のセグメンテーション(画像の中から細胞の部分のみを取り出す)で用いられる U-Net を見てみます。途中で FCN についても扱います。

U-Net


元論文:
arxiv.org


前回のAlexNetは、少ないサンプルでしか適用できなかったCNNを様々なテクニックにより大規模データセットにも利用できるようにしました。しかし、医用画像処理などの文脈では次の問題が残ってしまいます。

  • 得たい出力は、画像の分類ではなく、画像内の特定の領域の抽出である(より正確には、元の画像で対応する領域のピクセルを選ぶ semantic segmentation)
  • 正解データは(人手の問題で)沢山は得られない
U-Net で得られるセグメンテーションの結果(論文より引用)

このようなタスクに対し、MCDNN (Cirsan et al., 2012) などの手法が提案されていました。これらは、画像の局所的な領域(パッチ)ごとに予測を行い、その結果を統合するもので、実質的な入力数を増やせるなどの利点がありましたが、局所性に振ることで全体的な文脈の把握が困難となるというトレードオフや実行速度がネックとなっていました。

U-Net (Olaf et al., 2015) では、"Fully Convolutional Network (FCN)" (Long et al., 2014) と呼ばれる全く異なるアプローチを改良することにより、これらの問題を解決しました。

FCN

arxiv.org


FCN はその名の通り、AlexNet や VGGNet が畳み込み層と全結合層からなるアーキテクチャであるのに対し、全ての層を"畳み込み"層としたものです。また、特徴を抽出する前半の畳み込み層については、既存の画像分類などに用いられる良質なモデルのパラメータをそのまま転用できます。

FCNが畳み込み層を転用する図(論文より引用)

例えば AlexNet は 5 つの畳み込み層の結果を単に flatten 処理することで長さ 4096 の1次元ベクトルとし、全結合層の処理によって 1000 次元のベクトルを出力します。しかし、flatten の処理にせよ、全結合層の処理にせよ、行っている操作は形状変化であり、畳み込みとして置き換えることが可能です(convolutionalization)。

これにより、224×224 に限定されていた入力画像のサイズを可変長にすることが出来ます。それは、畳み込み処理は局所的なフィルター(関数)に過ぎないので、全体の大きさに依存しないからです。なお、終盤の畳み込み処理は各画素ごとに、すなわち 1×1 のサイズのカーネルを用いており、点単位畳み込みなどと呼ばれ、応用されているようです。


FCNの論文では、実際に 500×500 のサイズの入力画像から 10×10×チャネル数 のテンソルを出力し、その速度が、単純に入力画像を 10×10 個に分割して AlexNet を入力するのに対して、 5 倍ほど速いことを指摘しています。これは、MCDNNなどパッチ予測よりも速度向上ができることを示していると言えます。

Deconvolution

しかし、このままでは前半の畳み込みでダウンサンプリングされた小さな出力画像しか得られず、元画像から領域を抽出する、といった高解像度のセグメンテーションには利用できません。

そこで、FCN では deconvolution (逆畳み込み, 転置畳み込み=Transposed convolution, ...)と呼ばれる層を学習させることでアップサンプリングすることにしています。基本的な考え方は畳み込みの逆操作であり、入力した行列の各要素ごとにフィルターをかけてカーネルのサイズに拡大します。

Deconvolution


より具体的に計算方法をまとめてみます。

まず最も基本的な畳み込み操作の具体的な計算方法について考えます。入力画像を (縦×横, チャネル数) のサイズの行列に並べなおし、カーネルに対応する行ベクトルを並べた行列と積を取り、再び画像の形に並べなおすことで、畳み込みと同じ計算をすることが出来ます(さらに定数を足すことも)。ここで、回りくどく見えますが、カーネルをあえて入力画像と同じサイズとし、実際に適用される場所以外を0で埋めるスパースなものとして扱います。

Conv2Dのスパースな行列積表示

一方、deconvolution は入力画像の各要素ごとにカーネルを定数倍し、畳み込みをする前の画像の形状のテンソルに足していくことで求めます。これは、上の行列積を単に転置したもので実装できることが分かります(じっくり見てみると確かにそうなっています)。

deconvolution のスパースな行列積表示

実際、Conv2D の backward 関数と deconvolution の forward 関数はほとんど同じとなり、 Conv2D の forward 関数が deconvolution の backward 関数と対応します。このように計算上転置で書けることから、転置畳み込みとも呼ばれるわけです。

こうして、入力画像のサイズをアップサンプリングする計算方法が分かりました。そして、Conv2Dと表裏一体であることからわかるように、パラメータを学習することも可能です。なお、この転置畳み込みによる画像の"生成"は、DCGANなど Decoder の基礎として利用されるようになっていきました(前半の畳み込み層は Encoder と捉えられます)。

FCNのアーキテクチャとスキップ接続

FCNは、このように 畳み込み層→点単位畳み込み層→転置畳み込み層 の形ですべて"畳み込み"からなるアーキテクチャを持ちます。論文ではさらに、アップサンプリングの際にそのまま元の画像の大きさにするのではなく、前半の畳み込み層で得られる特徴マップを利用して精度を向上させられることを指摘しています。すなわち、(入力画像より小さな)特徴マップのサイズに一旦アップサンプリングしたのち、特徴マップと足し合わせたものを次のアップサンプリングに利用します。このような接続をスキップ接続と呼びます。

スキップ接続(論文より引用)

このようにスキップ接続を用いることで、確かに精度が向上したことが報告されています。

スキップ接続の効果(論文より引用)
U-Netのアーキテクチャ

U-Net は FCN の転置畳み込み層をより多くし、U字状に書けるアーキテクチャとなっています。up-conv が転置畳み込みに、conv 1×1 は点単位畳み込みに対応します。また、スキップ接続の際、(出力サイズに応じて切り取られた)特徴マップをアップサンプリングしてきたテンソルに連結し、畳み込みを行うようにしています。

U-Netのアーキテクチャ(論文より引用)

なお、FCNと同様に、U-Netは任意のサイズの画像を入力とすることができ、任意のサイズのセグメンテーション画像を出力とすることが出来ます。

他にも、生物学的によくある変形を駆使することで data augmentation を強くし、少ないサンプル数でも精度が出るような工夫があったり、同じクラスの細胞のセグメンテーションのために、重み付きの損失関数を導入するなどの工夫もあるようです。

細かい点はさておき、畳み込み(Encoder)と転置畳み込みによるアップサンプリング(Decoder)を一つのネットワークに組み込み、特徴マップをスキップ接続するというアーキテクチャが提案され、その後の多くのアーキテクチャの基盤となりました。

U-Net アーキテクチャのPyTorchコード

今回も、適当な 疑似 PyTorch コードを載せようと思います。

x1 = doubleConv(x, 3, 64)
x2 = down(x1, 64, 128)
x3 = down(x2, 128, 256)
x4 = down(x3, 256, 512)
x5 = down(x4, 512, 1024)
x6 = up(x5, x4, 1024, 512)
x7 = up(x6, x3, 512, 256)
x8 = up(x7, x2, 256, 128)
x9 = up(x8, x1, 128, 64)
y = Conv2d(x9, in_channels=64, out_channels=cls_num, kernel_size=1)

doubleConv(in_channels, out_channels) = Sequential(
    Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    BatchNorm2d(out_channels), ReLU(),
    Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
    BatchNorm2d(out_channels), ReLU()
)

down(in_channels, out_channels) = Sequential(
    MaxPool2d(kernel_size=2),
    doubleConv(in_channels, out_channels)
)

up(x2, in_channels, out_channels) = Sequential(
    ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2),
    crop_and_concat(x2),
    doubleConv(in_channels, out_channels)
)

なお、U-Net の up-conv は転置畳み込み以外のアップサンプリングの方法(バイリニア補完など)を使うこともあるようです。

(参考)
github.com

次回は、Transformer について扱います。

cake-by-the-river.hatenablog.jp