社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

PytorchによるImage Segmentation(1)

はじめに

Image Segmentationは画像を領域ごとに分割する技術であり、Conputre Vision技術の中でも特に汎用性の高い技術です。
応用範囲は非常に広く、医療画像解析や自動運転技術、工場の自動化、ロボティクス等々の様々な分野に応用されています。

以下に、Cityscapesと呼ばれる、Image Segmentationの代表的なデータセットの画像を示します。
f:id:nsr_9:20210901141618p:plain

このデータセットは、フロントカメラによって撮影された走行画像で構成されています。
ターゲットのラベルは、ピクセル毎にクラスラベルが付与されています。
具体的には、画像中で車が写っているピクセルの位置には車ラベル、道路が写っているピクセルには道路ピクセルを付与する〜というイメージです。
多分、言葉で説明するより、絵を見たほうが直感的に理解できると思います。

Image SegmentationにはCityscapesの様な画像を複数のカテゴリに分割するMulti-Class Segmentationと、ターゲットのカテゴリとそれ以外に分割するBinary Segmentation(って呼んでるんですけど一般的なんですかね・・?)があります。

Multi-Classの例 Binaryの例
f:id:nsr_9:20210901141618p:plain f:id:nsr_9:20210901145442p:plain

[1] Cityscapes Dataset – Semantic Understanding of Urban Street Scenes
[2] The STARE Project

今回からこのBinary SegmentationをPytorchで実装して行きたいと思います。

U-Net

Binary Image Segmentaionは一般的な画像分類と異なり、ピクセル単位で画像分類を行う必要があります。
その為、Networkの構造も画像分類のものから大きく改造する必要があります。
Image Segmentaion用のNetworkモデルは、日々新たなモデルが公開されています(いました)が、U-Net[3]と呼ばれるVGGベースのEncoder-Decoderモデルが最もシンプルで有名だと思います。

ネットワーク構造は次のような感じで、左から入力された画像を段階的に4回Down Samplingした後に4回Up Samplingし、最終的には入力画像と同じ解像度の出力を得ます。
f:id:nsr_9:20210901150932p:plain

このDown Sampling部分をEncoder Blockと呼び、UpSamplingする部分をDecoder Blockを呼びます。
U-NetはDecoder BlockでUpsamplingする際に、Encoder Blockからヒントを得る事で高解像化を実現しています。(これをSkip Connectと呼ばれ、多くのEncoder-Decoder Modelに応用されています)

このネットワークモデルをpytorchで実装すると以下の様になります。
実装する際は、以下のgithub projectを参照しました。
GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(VGGBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x = self.activation(x)
        return x

class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool1 = nn.MaxPool2d(2)
        self.vggblock1 = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x):
        x = self.pool1(x)
        x = self.vggblock1(x)
        return x

class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.vggblock = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2])
        x = torch.cat([x2, x1], dim=1)

        return self.vggblock(x)


class UNet(nn.Module):

    def __init__(self, classes):
        super(UNet, self).__init__()

        self.conv1 = VGGBlock(3, 64, 64)
        self.down1 = DownSampling(64, 128)
        self.down2 = DownSampling(128, 256)
        self.down3 = DownSampling(256, 512)
        self.down4 = DownSampling(512, 512)

        self.up1 = UpSampling(1024, 256)
        self.up2 = UpSampling(512, 128)
        self.up3 = UpSampling(256, 64)
        self.up4 = UpSampling(128, 64)
        self.out = nn.Conv2d(64, classes, kernel_size=1)


    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.out(x)
        return x

pytorchでNetworkモデルを作成する際は、Block単位(幾つかのConvolution LayerやPooling Layerで構成された小さなNetwork)でNetworkを定義し、それを積み木のように組み合わせて定義する事ができます。

作成したネットワークを出力すると次のようになりました。

UNet(
  (conv1): VGGBlock(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU(inplace=True)
  )
  (down1): DownSampling(
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (vggblock1): VGGBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (down2): DownSampling(
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (vggblock1): VGGBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (down3): DownSampling(
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (vggblock1): VGGBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (down4): DownSampling(
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (vggblock1): VGGBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (up1): UpSampling(
    (up): Upsample(scale_factor=2.0, mode=bilinear)
    (vggblock): VGGBlock(
      (conv1): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (up2): UpSampling(
    (up): Upsample(scale_factor=2.0, mode=bilinear)
    (vggblock): VGGBlock(
      (conv1): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (up3): UpSampling(
    (up): Upsample(scale_factor=2.0, mode=bilinear)
    (vggblock): VGGBlock(
      (conv1): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (up4): UpSampling(
    (up): Upsample(scale_factor=2.0, mode=bilinear)
    (vggblock): VGGBlock(
      (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU(inplace=True)
    )
  )
  (out): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

このネットワークを絵にすると次のようになります。

f:id:nsr_9:20210901180143p:plain

一気に学習コードまで作ろうかと思ったのですが思ったより大規模な変更が必要だったので、数回に分けてBinary Image Segmentationを実装していきます。

[3] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.