社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

Pytorchでグレースケール画像の着色

はじめに

pytorchでGANをやっているのですが、乱数源から顔画像を生成するtutorialを繰り返すのも芸がないので、白黒画像(グレースケール)からカラー画像を復元するタスクをやろうと思いました。 今回は、GANで取り組む前にPix2Pixという教師ありの画像生成モデルで、グレースケール画像の着色をやっていきます。

今回 Pix2Pix 次回 GAN
f:id:nsr_9:20210908155838p:plain f:id:nsr_9:20210908155849p:plain

グレースケールの着色

今回はカラー画像を一旦グレースケールに変換し、グレースケール画像からカラー画像に戻す様なスキームでネットワークを学習します。
グレースケールの変換は、OpenCVの関数を使うと簡単にできます。

import cv2
rgb_img = cv2.imread("img.png")
gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY)
カラー画像 グレースケール
f:id:nsr_9:20210908162413j:plain f:id:nsr_9:20210908162422j:plain

グレースケールを入力、カラー画像を正解のペアとして、U-Netを使って学習していきます。
前回のImage Segmentationで使ったU-Netは出力層がSoftmaxでしたが、今回は3chのチャンネル間で独立した輝度値(0以上の実数値)を出力するので、ちょっと改造します。

import torch
import torch.nn as nn
import torch.nn.functional as F



class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(VGGBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x = self.activation(x)
        return x

class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool1 = nn.MaxPool2d(2)
        self.vggblock1 = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x):
        x = self.pool1(x)
        x = self.vggblock1(x)
        return x

class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.vggblock = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2])
        x = torch.cat([x2, x1], dim=1)

        return self.vggblock(x)


class UNet(nn.Module):

    def __init__(self, classes):
        super(UNet, self).__init__()

        self.conv1 = VGGBlock(1, 64, 64)
        self.down1 = DownSampling(64, 128)
        self.down2 = DownSampling(128, 256)
        self.down3 = DownSampling(256, 512)
        self.down4 = DownSampling(512, 512)

        self.up1 = UpSampling(1024, 256)
        self.up2 = UpSampling(512, 128)
        self.up3 = UpSampling(256, 64)
        self.up4 = UpSampling(128, 64)
        self.out = nn.Conv2d(64, classes, kernel_size=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.out(x)
        x = self.act(x)
        return x

また、DataLoaderはGrayScale画像とカラー画像のペアを返すだけなので、非常にシンプルになります。

import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, write_jpeg
import glob
from torch.utils.data import DataLoader
from torchvision import transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


class CustomImageDataset(Dataset):
    def __init__(self, train=True):

        self.train = train
        if train:
            self.image_path = glob.glob("./train/images/*.jpg")
            self.label_path = glob.glob("./train/labels/*.jpg")
        else:
            self.image_path = glob.glob("./test/images/*.jpg")
            self.label_path = glob.glob("./test/labels/*.jpg")

        self.image_path.sort()
        self.label_path.sort()

    def __len__(self):
        return len(self.image_path)


    def __getitem__(self, idx):
        img_path = self.image_path[idx]
        label_path = self.label_path[idx]

        image = read_image(img_path)
        label = read_image(label_path)

        # ランダムで左右にフリップする
        if torch.randn(1) > 0:
            image = transforms.functional.hflip(image)
            label = transforms.functional.hflip(label)

        # ランダムで256x256をクロッピングする
        i, j, h, w = transforms.RandomCrop.get_params(image, (128, 128))
        image = transforms.functional.crop(image, i, j, h, w)
        label = transforms.functional.crop(label, i, j, h, w)
        sample = {"image": image.float(), "label": label.float()}
        return sample

train/imagesフォルダにグレースケール画像、train/labelsフォルダにそれと対応したカラー画像を入れると、DataLoaderがいい感じに読み込んでくれます。

f:id:nsr_9:20210908165211p:plain

これを30分くらい学習しました。
前回学習した際はデータ数が3000枚くらいだったのですが、今回は40000枚と10倍以上増えました。

学習した結果は次のようになりました。
左が入力画像のグレースケール、真ん中が推論画像、一番右が正解のカラー画像です。

入力画像 推論画像 正解画像
f:id:nsr_9:20210908214036j:plain f:id:nsr_9:20210908214307j:plain f:id:nsr_9:20210908214043j:plain
f:id:nsr_9:20210908214314j:plain f:id:nsr_9:20210908214323j:plain f:id:nsr_9:20210908214330j:plain
f:id:nsr_9:20210908214339j:plain f:id:nsr_9:20210908214409j:plain f:id:nsr_9:20210908214530j:plain
f:id:nsr_9:20210908214347j:plain f:id:nsr_9:20210908214416j:plain f:id:nsr_9:20210908214437j:plain
f:id:nsr_9:20210908214353j:plain f:id:nsr_9:20210908214423j:plain f:id:nsr_9:20210908214443j:plain

所々「ん?」って感じの結果になっていますが、アスファルトや木など、直感的に色が想像しやすい対象物に関しては、思ったよりもちゃんと色復元されているように見えますね! 今回は、学習するデータセットドメイン(背景知識みたいな意味)が比較的にはっきりしているので、それっぽい学習がされているように思えます。

今回はPix2Pixで学習しましたが、次回はこの色塗りタスクをGANでやっていきたいと思います。