社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

PytorchによるImage Segmentation(2)

はじめに

nsr-9.hatenablog.jp この記事の続きです。
Pytorchを用いてU-NetのImage Segmentationをやっていきます。

前回はU-Netのモデルを定義したので、今回はDataLoader部分を作っていきます。

Image Segmentation用のDataset

Image Segmentationを簡単に表現すると、CNNで自動的に色塗りをさせるようなイメージです。
車や人が写ってる領域や、医療画像の血管領域だけ色塗りさせたり等、かなり器用にタスクをこなしてくれます。

Image Segmentationを学習する際は、画像と色塗りした模範画像(ラベル)のペアを100枚から数万枚単位で用意します。 以下に入力データの参考例として、cityscapes dataset[1]のデータを載せます。

入力画像の参考例 模範画像の参考例
f:id:nsr_9:20210904103935p:plain f:id:nsr_9:20210904103942p:plain

学習に必要な枚数はターゲットとなるタスクの難しさによって増減しますが、大体1000枚あればそれなりの精度が出ます。
また、ラベル画像はアノテーションツールを用いて作成します。
数年前はLabelMeというアノテーションツールしかなかったのですが、昨今では様々なアノテーションツールが公開されているようですね。
セマンティック・セグメンテーション用アノテーションツール – demura.net

今回は、ラベル画像を位置から作るのではなく、オープンに公開されているCityscapes Datasetを加工して、Image Segmentation用のデータセットを作成しました。 Cityscapes Datasetは道路や空、街路樹、歩行者など様々な対象物を色分けしたMulti-classのSemantic Segmentation Datasetなのですが、車のみを対象としたImage Segmentation Datasetにします。

f:id:nsr_9:20210904105926p:plain

入力画像をimages/、ラベル画像をlabels/にした際のDataLoaderは次のようになりました。

import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, write_jpeg
import glob
from torch.utils.data import DataLoader
from torchvision import transforms as transforms
import numpy as np


class CustomImageDataset(Dataset):
    def __init__(self, train=True):

        self.train = train
        if train:
            self.image_path = glob.glob("./dataset/train/images/*.jpg")
            self.label_path = glob.glob("./dataset/train/labels/*.png")
        else:
            self.image_path = glob.glob("./dataset/val/images/*.jpg")
            self.label_path = glob.glob("./dataset/val/labels/*.png")

        self.image_path.sort()
        self.label_path.sort()

        self.gray = transforms.Grayscale()
        self.resize = transforms.Resize((320, 640))

    def __len__(self):
        return len(self.image_path)


    def __getitem__(self, idx):
        img_path = self.image_path[idx]
        label_path = self.label_path[idx]

        image = read_image(img_path)
        label = self.gray(read_image(label_path))

        # 車以外のラベルを0にする
        label[label != 25] = 0
        label[label == 25] = 1

        image = self.resize(image)
        label = self.resize(label)

        # ランダムで左右にフリップする
        if torch.randn(1) > 0:
            image = transforms.functional.hflip(image)
            label = transforms.functional.hflip(label)

        # ランダムで256x256をクロッピングする
        i, j, h, w = transforms.RandomCrop.get_params(image, (256, 256))
        image = transforms.functional.crop(image, i, j, h, w)
        label = transforms.functional.crop(label, i, j, h, w)

        label = torch.nn.functional.one_hot(label[0].long(), num_classes=2)
        label = torch.permute(label, (2, 0, 1))
        sample = {"image": image.float(), "label": label.float()}
        return sample


if __name__ == "__main__":
    bsize = 8
    train_data = CustomImageDataset()
    train_dataloader = DataLoader(train_data, batch_size=bsize, shuffle=True)

    data = next(iter(train_dataloader))
    images = data["image"]
    labels = data["label"]
    for i in range(bsize):
        image = images[i]
        label = labels[i]
        label = torch.stack([label[1], label[1], label[1]])
        print(label.size())
        write_jpeg(image.to(torch.uint8), "image_{0:01d}.jpg".format(i))
        write_jpeg(label.to(torch.uint8)*255, "label_{0:01d}.jpg".format(i))

このデータローダーを走らせると、次のような学習バッチが生成されます。

image label
f:id:nsr_9:20210904114931p:plain f:id:nsr_9:20210904114937p:plain

入力画像に対応したラベル画像が生成されていますね!

まとめ

今回はImage Segmentation Datasetを読み込むDataloaderを作成しました。
次回はU-Netを学習させるLoss Fuctionを実装し、学習を行ってみたいと思います。

[1] Cityscapes Dataset – Semantic Understanding of Urban Street Scenes