Pytorchでグレースケール画像の着色
はじめに
pytorchでGANをやっているのですが、乱数源から顔画像を生成するtutorialを繰り返すのも芸がないので、白黒画像(グレースケール)からカラー画像を復元するタスクをやろうと思いました。 今回は、GANで取り組む前にPix2Pixという教師ありの画像生成モデルで、グレースケール画像の着色をやっていきます。
今回 Pix2Pix | 次回 GAN |
---|---|
グレースケールの着色
今回はカラー画像を一旦グレースケールに変換し、グレースケール画像からカラー画像に戻す様なスキームでネットワークを学習します。
グレースケールの変換は、OpenCVの関数を使うと簡単にできます。
import cv2 rgb_img = cv2.imread("img.png") gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY)
カラー画像 | グレースケール |
---|---|
グレースケールを入力、カラー画像を正解のペアとして、U-Netを使って学習していきます。
前回のImage Segmentationで使ったU-Netは出力層がSoftmaxでしたが、今回は3chのチャンネル間で独立した輝度値(0以上の実数値)を出力するので、ちょっと改造します。
import torch import torch.nn as nn import torch.nn.functional as F class VGGBlock(nn.Module): def __init__(self, in_channels, middle_channels, out_channels): super(VGGBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) self.bn1 = nn.BatchNorm2d(middle_channels) self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) self.activation = nn.ReLU(inplace=True) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.activation(x) x = self.conv2(x) x = self.bn2(x) x = self.activation(x) return x class DownSampling(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.pool1 = nn.MaxPool2d(2) self.vggblock1 = VGGBlock(in_channels, out_channels, out_channels) def forward(self, x): x = self.pool1(x) x = self.vggblock1(x) return x class UpSampling(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.vggblock = VGGBlock(in_channels, out_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2]) x = torch.cat([x2, x1], dim=1) return self.vggblock(x) class UNet(nn.Module): def __init__(self, classes): super(UNet, self).__init__() self.conv1 = VGGBlock(1, 64, 64) self.down1 = DownSampling(64, 128) self.down2 = DownSampling(128, 256) self.down3 = DownSampling(256, 512) self.down4 = DownSampling(512, 512) self.up1 = UpSampling(1024, 256) self.up2 = UpSampling(512, 128) self.up3 = UpSampling(256, 64) self.up4 = UpSampling(128, 64) self.out = nn.Conv2d(64, classes, kernel_size=1) self.act = nn.ReLU(inplace=True) def forward(self, x): x1 = self.conv1(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) x = self.out(x) x = self.act(x) return x
また、DataLoaderはGrayScale画像とカラー画像のペアを返すだけなので、非常にシンプルになります。
import os import torch from torch.utils.data import Dataset from torchvision.io import read_image, write_jpeg import glob from torch.utils.data import DataLoader from torchvision import transforms as transforms import matplotlib.pyplot as plt from PIL import Image import numpy as np class CustomImageDataset(Dataset): def __init__(self, train=True): self.train = train if train: self.image_path = glob.glob("./train/images/*.jpg") self.label_path = glob.glob("./train/labels/*.jpg") else: self.image_path = glob.glob("./test/images/*.jpg") self.label_path = glob.glob("./test/labels/*.jpg") self.image_path.sort() self.label_path.sort() def __len__(self): return len(self.image_path) def __getitem__(self, idx): img_path = self.image_path[idx] label_path = self.label_path[idx] image = read_image(img_path) label = read_image(label_path) # ランダムで左右にフリップする if torch.randn(1) > 0: image = transforms.functional.hflip(image) label = transforms.functional.hflip(label) # ランダムで256x256をクロッピングする i, j, h, w = transforms.RandomCrop.get_params(image, (128, 128)) image = transforms.functional.crop(image, i, j, h, w) label = transforms.functional.crop(label, i, j, h, w) sample = {"image": image.float(), "label": label.float()} return sample
train/imagesフォルダにグレースケール画像、train/labelsフォルダにそれと対応したカラー画像を入れると、DataLoaderがいい感じに読み込んでくれます。
これを30分くらい学習しました。
前回学習した際はデータ数が3000枚くらいだったのですが、今回は40000枚と10倍以上増えました。
学習した結果は次のようになりました。
左が入力画像のグレースケール、真ん中が推論画像、一番右が正解のカラー画像です。
入力画像 | 推論画像 | 正解画像 |
---|---|---|
所々「ん?」って感じの結果になっていますが、アスファルトや木など、直感的に色が想像しやすい対象物に関しては、思ったよりもちゃんと色復元されているように見えますね! 今回は、学習するデータセットのドメイン(背景知識みたいな意味)が比較的にはっきりしているので、それっぽい学習がされているように思えます。
今回はPix2Pixで学習しましたが、次回はこの色塗りタスクをGANでやっていきたいと思います。