社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

PytorchによるImage Segmentation(3)

はじめに

nsr-9.hatenablog.jp

PytorchによるU-NetのImage Segmentationの続きです。
前回はDataLoader部分を実装したので、今回はLoss Functionを実装して実際に学習、推論をやっていきます。

Image SegmentationのLoss Function

2クラスのImage Segmentationは、普通のCategorical Cross EntropyなどのLoss関数ではうまく学習できない事が多いです。
分類するクラス(例えば車両と背景)のピクセル数の出現頻度が極端に偏ってしまう事が原因です。

Cityscapesから適当に画像を参照し、車両と背景のピクセル数をカウントすると、次のようになります。

入力ラベル 出現頻度
f:id:nsr_9:20210905044701p:plain f:id:nsr_9:20210905044715p:plain

車両と背景では出現頻度が20倍近く異なります。
この様に出現頻度が大幅に偏っている状態だと、「全部背景カテゴリである!」と適当に推論してもそれなりに正解になってしまい、うまく学習を進める事ができません。

その様なアンバランスな状態を解消するLoss関数がDice Loss Function[1]です。
Dice Lossは次のような式で表されます。

f:id:nsr_9:20210905052744p:plain

直感的な表現をすると、予測値または推測値に1が含まれる(車両に関係する)ピクセルのみを使ってLoss値を求めています。

f:id:nsr_9:20210905052909p:plain

こうすることで大多数を占める予測値も推測値も背景であるピクセルを評価から除外する事ができるので、「とりあえず全部背景!」という適当な推論をしても正解にならなくなります。

Dice Loss Functionですが、いい感じに実装を公開してくれている方がいらっしゃったので、今回はそちらを参考に実装しました。
pytorch-unet/loss.py at master · usuyama/pytorch-unet · GitHub

from torch import nn
import torch.nn.functional as F
import torch


def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def calc_loss(pred, target, metrics=None, bce_weight=0.5):
    # Dice LossとCategorical Cross Entropyを混ぜていい感じにしている
    bce = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    return loss

U-Netの学習

Networkモデル、DataLoader、Loss Functionを実装できたので、いよいよ学習させていきます。
学習のコードは次のようになります。

from dataset import CustomImageDataset
from network import UNet
import torch
from torch.utils.data import DataLoader
import torchvision.models as models
import os
from losses import FocalLoss, calc_loss
import matplotlib.pyplot as plt


device = "cuda"

def train_loop(dataloader, model, loss_fn, opt):
    data_size = len(dataloader.dataset)

    for batch, D in enumerate(dataloader):
        X = D["image"].to(device)
        y = D["label"].to(device)
        opt.zero_grad()

        pred = model(X)
        loss = loss_fn(pred, y)

        # バックプロパゲーション
        loss.backward()
        opt.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch*len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{data_size:>5d}]")


if __name__ == "__main__":

    #loss_fn = FocalLoss(1)
    #loss_fn = torch.nn.CrossEntropyLoss()
    loss_fn = calc_loss
    model = UNet(2).to(device)
    #model = torch.load("model.pth")
    optimizer = torch.optim.Adam(model.parameters())
    train_data = CustomImageDataset()
    train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=os.cpu_count())
    epochs = 100
    import time
    start = time.time()
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)


    print("time: "+str(time.time()-start))

    torch.save(model, "model.pth")

DataLoader(train_data, batch_size=16, shuffle=True, num_workers=os.cpu_count()

ここのnum_workersでCPU数を指定する事で、DataLoaderを並列処理にすることができます。
Image Segmentationの様に画像が大きく読み込み時間がボトルネックになる場合や、計算量の多いData Augmentationを行う場合は並列処理にすることで大幅に実行時間を短縮する事ができます。

学習を開始するとGPUのファンが高速回転し始め、電気をモリモリ消費していきます。
数時間くらいで学習が終わるのですが、もっと大規模な学習を行った時の電気代が非常に気になるところです。

推論結果

学習後のモデルを使って色々な画像を推論してみました。
推論コードは次のようになります。

import torch
from torchvision.io import read_image
import sys
from torchvision import transforms as transforms
import numpy as np


model = torch.load("./model.pth")


def run(path):
    img = read_image(path).float().to("cuda")
    resize = transforms.Resize((320, 640))
    img = resize(img)

    out = model(img[None,])
    out = torch.nn.functional.sigmoid(out)

    out = out.detach().to("cpu")
    out = torch.argmax(out[0], axis=0)
    out = out.numpy()
    plt.imshow(out)
    plt.show()


if __name__ == "__main__":
    while True:
        path = input()[1:-2]
        run(path)

Cityscapesの評価画像を推論すると次のようになります。

入力画像 推論画像
f:id:nsr_9:20210905060003j:plain f:id:nsr_9:20210905060009j:plain
f:id:nsr_9:20210905060042j:plain f:id:nsr_9:20210905060052j:plain
f:id:nsr_9:20210905060102j:plain f:id:nsr_9:20210905060108j:plain

概ね正しく推論できていそうですね! また、Cityscapesにはdemo用のビデオがあるので、そちらに対しても推論してみました。
わかりやすくする為に推論結果を青色にして、入力画像に重ねてみました。
f:id:nsr_9:20210905062826j:plain

f:id:nsr_9:20210905062218g:plain

f:id:nsr_9:20210905062301g:plain

まとめ

PytorchでImage Segmentation手法であるU-Netを実装し、Cityscapes Datasetを用いて車両のバイナリSegmentationを学習、推論しました。 Image Segmentationは様々な画像処理技術へ応用が可能なので、今後はこの技術を使って色々作っていきたいと思います。

[1] Sudre, Carole H., et al. "Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations." Deep learning in medical image analysis and multimodal learning for clinical decision support. Springer, Cham, 2017. 240-248.