社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

pythonで音声信号処理のことはじめ

はじめに

今までは画像処理や画像認識ばっかりやってきたのですが、ふと音声も画像と同じパターン認識技術の一つだなぁと思い、ここは一つ音声も触れるようになりたいなと思いました。
一昔前は音声信号処理を頑張るにはC言語の開発環境とライブラリ群をインストールするのが大変だと聞いていたのですが、昨今ではpython pip一発で環境整備ができるようになったと聞きます。
本当にpython様様ですね。
今回からpythonで音声信号処理をやってみたいと思います。

音声信号処理

Wikipediaを引用すると音声信号処理とは"音または音を表す信号を処理すること"を指すようです。
音響信号処理 - Wikipedia

パッと思いつく応用例としては、音声通話(電話, VoIP)、音声圧縮(MP3)、音声認識(Siri, 書き起こし)、音響技術(ハイレゾ, ノイズキャンセリング, ノイズ補正)等々でしょうか。
よくよく思い返すと非常に広い分野で社会に根付いた技術であるなぁと思いました。

とても広い応用先ですが、構成技術も非常に多く、体系的にイメージすることが難しそうですね。
以下に、今僕がイメージしている技術体系を図にしてみました。

f:id:nsr_9:20210919140415p:plain

音声技術についてはあまり詳しくないのですが、この中でも特に信号処理の部分がホットトピックスなんでしょうかね?
具体的にお勉強してみたい技術としては、ノイズキャンセリング、指向性マイク(ある方向やある種類の音を聞き取りやすくするマイク)、望遠マイク(遠くの音を聞き取る)、音声変換(いつも聞いてる楽曲を違うテイストに変換する)などです。

まだ全然調査できていないので間違っているかもしれませんが、多分このピンクの丸の部分が必要な技術になると思います。

f:id:nsr_9:20210919141107p:plain

準備したもの

暗中模索状態ですが、何はともあれ録音装置、音声処理用の演算装置、音声再生装置が必要である事は間違いないと思います。
音声処理用の演算装置は、汎用PC(Ubuntu)+Pythonで賄うとして、ポータビリティ性の高い録音装置、再生装置が必要です。
色々調べた結果、TASCAMのDR-05と言うリニアPCMレコーダーを購入しました。
DR-05 | 製品トップ | TASCAM (日本)

この装置は名前の通り乾電池で動く録音機なのですが、これが中々面白い機能が沢山ついています。

まず外観ですが、シンプルでとてもかっこいいです。
f:id:nsr_9:20210919144850p:plain
f:id:nsr_9:20210919145037p:plain

録音機単体で録音した音声の再生や簡単な音響処理(エコーとか多重録音とか)ができます。
裏面にはマウント用のネジ(カメラマウントと互換)と電池ボックスが見えます。

f:id:nsr_9:20210919145254p:plain
こんな感じの三脚を用意すれば簡単に固定することができます。
f:id:nsr_9:20210919145427j:plain
f:id:nsr_9:20210919145436j:plain

電池ボックスには単3電池が2本入ります。
f:id:nsr_9:20210919145643j:plain
充電式も好きですが、 出先で電池切れになった時にコンビニに駆け込めばすぐに動かすことができるので、カートリッジ式もとても便利ですね。

側面にはマイクロSDカードスロットとφ3.5mmのイヤフォンジャックがついています。
f:id:nsr_9:20210919145930j:plain
f:id:nsr_9:20210919145937j:plain

イヤフォンジャックをつけなくても録音やスピーカー再生する事ができます。

最後にフロント部分にはL/Rのステレオマイクとステレオの外部マイク用のジャックがついています。 f:id:nsr_9:20210919150048j:plain

簡単に録音した音声を聞きましたが、非常にクリアに聞こえてとてもビックリしました。
普段、Web会議やボイスチャットでは安物のマイクを使ってたのですが、高性能なモノはここまで凄いのか!!と童心に帰って色々な音声を録音しては再生するという遊びをしてしまいました。

また、L/Rのステレオマイク構成なので、昨今話題の立体音響録音などができます。
こんな感じにレコーダーの周辺をペンでコンコン叩いてみました。
f:id:nsr_9:20210919153638p:plain

いい感じに立体音響で録音することができました。

終わりに

今回は、音声信号処理を勉強するためにレコーダーを購入し、色々遊んでみました。
次回からは録音したwavファイルをpythonで扱ってみたいと思います。

CHWとHWCの相互変換

やりたいこと

opencvの画像は縦(Height)×横(Width)×チャンネル(Channel)のいわゆるHWC形式になっています。
それとは異なり、pytorchの画像ではチャンネル(Channel)×縦(Height)×横(Width)のCHW形式になっています。

numpy arrayとtorch tensorにおいて、HWCとCHWの相互変換方法を示します。

Numpy Arrayでの軸変換

np.transposeで軸を変換します。

import numpy as np
H, W, C = 128, 256, 3
img = np.random.random([H, W, C])    # H, W, C
print(img.shape)

img = np.transpose(img, (2, 0, 1))        # C, H, W
print(img.shape)

実行すると次のような出力を得ます。

(128, 256, 3)  
(3, 128, 256)

ちゃんと軸変換ができました!

Pytorch Tensorでの軸変換

torch.permuteで軸変換ができます。

import torch
H, W, C = 128, 256, 3
img = torch.randn((C, H, W))         # C, H, W
print(img.size())
img = torch.permute(img, (1, 2, 0)) # H, W, C
print(img.size())

実行すると次のような出力を得ます

torch.Size([3, 128, 256])
torch.Size([128, 256, 3])

こちらもちゃんと軸変換できました!

液体レンズの試作(2)

はじめに

前回、液体レンズの仕組みと基本構成について考えてみました。

nsr-9.hatenablog.jp

今回から実際に作っていきたいと思います。

材料

今回の簡易液体レンズは以下の要素で構成されます。

要素 役割 主な素材
透明シート レンズ面の構成 透明塩ビシート, サランラップ
透明液体 レンズ 水, 高分子ポリマー等
リング 容器の密閉 合成ゴム, プラスティック等
シリンジ 液体の注入 注射器のアレ

これらの材料は散歩がてらに近所のホームセンターを散策し、それっぽいものを見繕いました。

透明シート

実際にはホームセンターに売ってた300円くらいの20cm×20cmのシートを買いました
f:id:nsr_9:20210914212115p:plain 引用: 透明ビニールシート 0.5mm厚×幅915mm×7m切 ロール巻き RoHS10対応 オカモトマジキリセレブ :sheet2020050901:上村シート ヤフー店 - 通販 - Yahoo!ショッピング

透明流体

お家の水道水をそのまま使いました

リング

30cm×30cmの合成ゴムの板をホームセンターで買いました
f:id:nsr_9:20210914212138p:plain
引用: NBR合成ゴムシート 15mm厚×1M幅×長さカット自由 (10cm単位) の通販 | 資材調達支援サイトGAOS(ガオス)

シリンジ

シリンジは家庭菜園コーナーにある注射器、チューブはφ5mmの透明チューブを買いました。

f:id:nsr_9:20210914212716p:plain

f:id:nsr_9:20210914212725p:plain

工具類

また、加工を行うために別途、工具も調達しました。

工具 用途
円形カッター 合成ゴムのカット
セメダイン 透明シートの接着
ハンドドリル 注入孔の穴あけ

実際に作ったもの

設計図(?)を参考に、何個か作ってみました。
f:id:nsr_9:20210915124829j:plain

f:id:nsr_9:20210915134448p:plain

f:id:nsr_9:20210915134519j:plain

透明な液体を入れると次のようになります。
f:id:nsr_9:20210915185208j:plain

シリンジをプッシュすると次のように膨らみます。
f:id:nsr_9:20210915185756j:plain ちょっと分かりづらいので、もう一つ載せます。

f:id:nsr_9:20210915190203p:plain
透明シートの輪郭をピンク色でなぞりました。
レンズのような形状になっていることが確認できますね!

この液体レンズを通して色々観察してみました。

プッシュ前 プッシュ後 比較
f:id:nsr_9:20210915190658p:plain f:id:nsr_9:20210915190707p:plain f:id:nsr_9:20210915190848g:plain

プッシュをすると写っている像が拡大されている(虫眼鏡のようになっている)事が確認できます。

f:id:nsr_9:20210915191406g:plain

最後に、液体レンズの簡単な説明を動画化しました。

www.youtube.com

まとめ

力学的な力に基づく液体レンズを試作しました。
思ったよりもきれいな像が得られてとても満足しました。
このレンズは頑張れば任意の焦点距離に設定できるので、なにかおもしろいことができそうな予感がします。

液体レンズの試作(1)

はじめに

液体レンズとは、透明な液体で構成されたレンズの事で、電圧や力学的な力を加える事で屈折率を変えるレンズです。
少ないエネルギーで焦点距離を自由に変更できる為、産業用カメラやスマートフォン、次世代のメガネ等々への応用が期待される面白い技術です。

www.youtube.com

液体レンズの研究の中でも特に興味深い研究は石川研究室の、ダイナモルフレンズです。

www.youtube.com

圧電素子と呼ばれる電気を流すと振動する素子を用いて、高速に焦点距離(液体の曲率)を変更できる液体レンズを試作しています。
また、焦点距離の高速調整機能を使って次のような衝撃的なデモを行っていました。

www.vision.ict.e.titech.ac.jp

今回は、この液体レンズを作ってみたいと思います。

レンズの仕組み

まず、一般的なレンズについて考えてみます。
高校物理の単元では、薄肉レンズモデルという幾何光学の中でも最も基本的なレンズモデルについて習うと思います。
薄肉レンズ(以下レンズ)モデルは、以下のように、光源(左のろうそく)、レンズ、投影面と像(右のろうそく)で構成されます。

f:id:nsr_9:20210913194736p:plain

このモデルで最も重要な部分は、焦点距離と物体の距離、結像位置(ピントが合う位置)は一つの式で関係づけられている事です。
この式をガウスのレンズ公式といいます。

ガウスのレンズ公式から、物体の距離が変わると結像位置が変わることがわかります。
以下に、ちょっと正確ではないのですが、焦点距離fと投影面bの位置を変えずに、物体の距離aだけ変えて撮像した図を載せます。

f:id:nsr_9:20210913195945p:plain

上段は正しく結像している状態の図で、下段がその状態から物体の距離を変えた時の模式図です。
投影面と結像位置が異なっている為、ぼやけた絵になります。
スマートフォンとかで近い距離と遠い距離が混在するシーンを撮影すると、この様な焦点ボケを観察する事ができます。

f:id:nsr_9:20210913200523p:plain

一般的なカメラのオートフォーカス機能は、レンズと投影面の距離をアクチュエータ等で変更する事で、様々な距離に距離の物体にピントを合わせる事ができます。
【カメラ用語事典】コントラストAF | CAPA CAMERA WEB

理論的にはとても素直なのですが、レンズと投影面の距離を物理的に変える必要がある為、(精密機器としては)大きなエネルギーが必要となり、ミリ秒単位での制御が難しいという課題があります。
それに対し、液体レンズではレンズの焦点距離fを変える事で、オートフォーカスを実現しています。
物理的な移動が伴わないため、非常に応答速度が早いという特長があります。

液体レンズ

液体レンズは、液体の曲率を変える事で焦点距離を変更するレンズです。
液体レンズの構成方法は様々なものがありますが、今回はその中でも最もシンプルなものを考えてみました。

f:id:nsr_9:20210913201542p:plain

まず、ゴム製のリングの両面を伸縮性のある透明なシートで密閉します。
次に、リングの側面に開けた注入孔からシリンジを用いて透明な流体を注入します。

この流体を十分に注入すると両面の透明シートが膨らむのですが、この際に透明シートは半球面を形成します。
レンズの焦点距離は半球の曲率(曲率半径)によって定まるので、流体の充填率によって調整する事ができます。

非常に簡単な仕組みですが、これで基本的な液体レンズを構成する事ができます。

次回は、この設計図(?)に従って液体レンズを試作していきたいと思います。

Image Registrationによる外観検査

はじめに

前回、位相限定相関法によるImage Registration手法で遊んでみました。
nsr-9.hatenablog.jp

Image Registrationの一つの応用先として、外観検査装置があります。
外観検査装置は、工場のラインで流れてくる製品を高速度カメラで撮影し、傷や付着物などの品質不良が無いか自動的に判別するシステムです。

以下の動画の6:40あたりから外観検査装置が登場しています。
youtu.be

外観検査装置は画像処理技術が最も早く社会実装された応用先だと思います。
今回は、pythonOpenCVを用いて、簡単な外観検査アルゴリズムを作ってみたいと思います。

画像処理による外観検査

外観検査のやり方は多種多様なものがありますが、その中でも最も簡単なものはテンプレートの差分を求める手法だと思います。
逐次ラインから流れてくる製品をカメラで撮影し、品質不良の無い理想のテンプレート画像と比較する事で異常を検出します。

f:id:nsr_9:20210911193243p:plain

テンプレートと比較する際は、一般的に画像差分を行います。
シンプルに実装すると次のようになります。

import cv2
img = cv2.imread("img.jpg")
tmp = cv2.imread("tmp.jpg")

cv2.asbdiff(img, tmp)

しかしながら、そのままこれで実行しようとすると、製品の画像がずれていた場合、悲しいことになってしまいます。

画像1 画像2 差分画像
f:id:nsr_9:20210911193706j:plain f:id:nsr_9:20210911193713j:plain f:id:nsr_9:20210911193745j:plain

差分画像について、色がついている部分がテンプレートと異なる領域なのですが、特に異常が無いにもかかわらず大部分が異なる領域であると出力されています。

これだと異常検知ができないので、前回遊んだImage Registrationを適用します。

import cv2
import sys
import numpy as np

if __name__ == "__main__":
    img1 = cv2.imread(sys.argv[1])
    img2 = cv2.imread(sys.argv[2])

    gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY).astype(np.float32)
    gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY).astype(np.float32)

    (x, y), r = cv2.phaseCorrelate(gray1, gray2)
    print(x, y, r)

    W = img1.shape[1]
    H = img1.shape[0]

    if x < 0:
        x11, x12 = abs(int(x)), W
        x21, x22 = 0, W-abs(int(x))
    else:
        x21, x22 = abs(int(x)), W
        x11, x12 = 0, W-abs(int(x))
    
    if y < 0:
        y11, y12 = abs(int(y)), H
        y21, y22 = 0, H-abs(int(y))
    else:
        y21, y22 = abs(int(y)), H
        y11, y12 = 0, H-abs(int(y))
    

    print(W, x12-x11, x22-x21)
    cv2.imwrite("out1.png", img1[y11:y12, x11:x12])
    cv2.imwrite("out2.png", img2[y21:y22, x21:x22])

Image Registrationを行った後に差分をとると次のようになります。

画像1 画像2 差分画像
f:id:nsr_9:20210911194326p:plain f:id:nsr_9:20210911194400p:plain f:id:nsr_9:20210911194415p:plain

Image Registration後の差分画像を見るとほとんどの領域の輝度値が0に近くなっている事がわかります。

では、いよいよ異常検知をやってみます。
以下の画像の様に、異常データとして黒のマジックでXマークをつけてみました。

f:id:nsr_9:20210911194524j:plain

画像1 画像2 差分画像
f:id:nsr_9:20210911195259p:plain f:id:nsr_9:20210911195314p:plain f:id:nsr_9:20210911195332p:plain

ちょっと分かりにくいので、差分画像を拡大してみました。

f:id:nsr_9:20210911195455p:plain

いい感じにXマークの部分が強調されていますね!

まとめ

Image Registrationと差分法で簡単な外観検査装置(の一部の機能)を作ってみました。
外観検査装置はさまざまな事に応用できるはずなので、これを使って色々作ってみたいと思います。

位相限定相関法による画像の位置合わせ(Image Registration)

はじめに

画像の位置合わせ(Image Registration)とは、異なる視点で撮影された2枚の画像の位置をいい感じにフィッティングする事です。
以下にサーベイ論文[1]にわかりやすい画像があったので、参照させてもらいます。
f:id:nsr_9:20210910163637p:plain

左上と右上の画像について、それぞれの対応関係(図中の+1~+6)を求め、左下の図のようにピッタリと位置が合うように画像変換を行います。
位置が合うように画像変換を行うことで、バラバラに撮影された画像からパノラマ画像のようなものが作成できています。

Image Registrationは、外観検査装置や指紋、虹彩認証、3D-Sensing、物体追跡等々、様々なアプリケーションに応用される汎用的な技術となっています。
今回は、Image Registration手法の中でも特に実用性が高いと言われている、位相限定相関法[2]をPythonで実装します。

位相限定相関法

位相限定相関法は、周波数変換を行った画像の位相成分のみを用いる事で、サブピクセル精度でのImage Registrationを実現しています。
画像はピクセルという最小単位の粒で構成されているのですが、サブピクセル単位で位置調整するという事は、その最小単位以下(例えば0.5ピクセル)の精度で位置合わせができるということです。
凄いですね!

f:id:nsr_9:20210910165736p:plain

位相限定相関法を使うだけならば、OpenCVを使えば簡単にできます。

import cv2
import sys
import numpy as np

if __name__ == "__main__":
    img1 = cv2.imread(sys.argv[1])
    img2 = cv2.imread(sys.argv[2])

    gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY).astype(np.float32)
    gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY).astype(np.float32)

    (x, y), r = cv2.phaseCorrelate(gray1, gray2)
    print(x, y, r)

cv2.phaseCorrelateは2つの入力に対して、位相限定相関法で位置合わせを行う手法になっています。
入力画像はグレースケールのfloat32かfloat64の入力データになります。

スマホで撮った以下の2枚の画像に対して適応してみます。

入力画像1 入力画像2 重ねた画像
f:id:nsr_9:20210910173140j:plain f:id:nsr_9:20210910173147j:plain f:id:nsr_9:20210910173338g:plain

この画像に対して位相限定相関法を適用すると、次のような出力を得ます。

xのズレ量: -103.03237594382273
yのズレ量: 8.51607273664581
感度: 0.19513171376671812

感度は算出した値の信頼度のようなものです。
この結果を用いて入力画像をトリミングするコードは次のようになります。

    W = img1.shape[1]
    H = img1.shape[0]

    if x < 0:
        x11, x12 = abs(int(x)), W
        x21, x22 = 0, W-abs(int(x))
    else:
        x21, x22 = abs(int(x)), W
        x11, x12 = 0, W-abs(int(x))
    
    if y < 0:
        y11, y12 = abs(int(y)), H
        y21, y22 = 0, H-abs(int(y))
    else:
        y21, y22 = abs(int(y)), H
        y11, y12 = 0, H-abs(int(y))
    

    print(W, x12-x11, x22-x21)
    cv2.imwrite("out1.png", img1[y11:y12, x11:x12])
    cv2.imwrite("out2.png", img2[y21:y22, x21:x22])

実行すると次のような画像が生成されます。

出力画像1 出力画像2 重ねた画像
f:id:nsr_9:20210910174237p:plain f:id:nsr_9:20210910174248p:plain f:id:nsr_9:20210910174410g:plain

概ね位置があっていますね! 今回は、手に持ったスマートフォンで2枚の画像を撮影しているため、6軸の自由度(X,Y,Zの並進、パン, チルト, ロールの回転)を持っています。
f:id:nsr_9:20210910175321p:plain

位相限定相関法は、X, Y軸の並進移動(もしくはロール回転)のみに対応している為、重ねた結果がちょっとズレているのだと思います。

まとめ

今回はOpenCVに実装された位相限定相関法を使い、Image Registrationをやってみました。
かなり適当に撮った画像でも位置合わせを行えていたので、前処理や撮影環境を整備すればそのままでも実応用ができそうだと感じました。

[1] Zitova, Barbara, and Jan Flusser. "Image registration methods: a survey." Image and vision computing 21.11 (2003): 977-1000.
[2] 青木孝文, et al. "位相限定相関法に基づく高精度マシンビジョン―ピクセル分解能の壁を越える画像センシング技術を目指して―." 電子情報通信学会 基礎・境界ソサイエティ Fundamentals Review 1.1 (2007): 1_30-1_40.

Pytorchでグレースケール画像の着色

はじめに

pytorchでGANをやっているのですが、乱数源から顔画像を生成するtutorialを繰り返すのも芸がないので、白黒画像(グレースケール)からカラー画像を復元するタスクをやろうと思いました。 今回は、GANで取り組む前にPix2Pixという教師ありの画像生成モデルで、グレースケール画像の着色をやっていきます。

今回 Pix2Pix 次回 GAN
f:id:nsr_9:20210908155838p:plain f:id:nsr_9:20210908155849p:plain

グレースケールの着色

今回はカラー画像を一旦グレースケールに変換し、グレースケール画像からカラー画像に戻す様なスキームでネットワークを学習します。
グレースケールの変換は、OpenCVの関数を使うと簡単にできます。

import cv2
rgb_img = cv2.imread("img.png")
gray_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2GRAY)
カラー画像 グレースケール
f:id:nsr_9:20210908162413j:plain f:id:nsr_9:20210908162422j:plain

グレースケールを入力、カラー画像を正解のペアとして、U-Netを使って学習していきます。
前回のImage Segmentationで使ったU-Netは出力層がSoftmaxでしたが、今回は3chのチャンネル間で独立した輝度値(0以上の実数値)を出力するので、ちょっと改造します。

import torch
import torch.nn as nn
import torch.nn.functional as F



class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(VGGBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)

        x = self.activation(x)
        return x

class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool1 = nn.MaxPool2d(2)
        self.vggblock1 = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x):
        x = self.pool1(x)
        x = self.vggblock1(x)
        return x

class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.vggblock = VGGBlock(in_channels, out_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX//2, diffX-diffX//2, diffY//2, diffY-diffY//2])
        x = torch.cat([x2, x1], dim=1)

        return self.vggblock(x)


class UNet(nn.Module):

    def __init__(self, classes):
        super(UNet, self).__init__()

        self.conv1 = VGGBlock(1, 64, 64)
        self.down1 = DownSampling(64, 128)
        self.down2 = DownSampling(128, 256)
        self.down3 = DownSampling(256, 512)
        self.down4 = DownSampling(512, 512)

        self.up1 = UpSampling(1024, 256)
        self.up2 = UpSampling(512, 128)
        self.up3 = UpSampling(256, 64)
        self.up4 = UpSampling(128, 64)
        self.out = nn.Conv2d(64, classes, kernel_size=1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.out(x)
        x = self.act(x)
        return x

また、DataLoaderはGrayScale画像とカラー画像のペアを返すだけなので、非常にシンプルになります。

import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image, write_jpeg
import glob
from torch.utils.data import DataLoader
from torchvision import transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np


class CustomImageDataset(Dataset):
    def __init__(self, train=True):

        self.train = train
        if train:
            self.image_path = glob.glob("./train/images/*.jpg")
            self.label_path = glob.glob("./train/labels/*.jpg")
        else:
            self.image_path = glob.glob("./test/images/*.jpg")
            self.label_path = glob.glob("./test/labels/*.jpg")

        self.image_path.sort()
        self.label_path.sort()

    def __len__(self):
        return len(self.image_path)


    def __getitem__(self, idx):
        img_path = self.image_path[idx]
        label_path = self.label_path[idx]

        image = read_image(img_path)
        label = read_image(label_path)

        # ランダムで左右にフリップする
        if torch.randn(1) > 0:
            image = transforms.functional.hflip(image)
            label = transforms.functional.hflip(label)

        # ランダムで256x256をクロッピングする
        i, j, h, w = transforms.RandomCrop.get_params(image, (128, 128))
        image = transforms.functional.crop(image, i, j, h, w)
        label = transforms.functional.crop(label, i, j, h, w)
        sample = {"image": image.float(), "label": label.float()}
        return sample

train/imagesフォルダにグレースケール画像、train/labelsフォルダにそれと対応したカラー画像を入れると、DataLoaderがいい感じに読み込んでくれます。

f:id:nsr_9:20210908165211p:plain

これを30分くらい学習しました。
前回学習した際はデータ数が3000枚くらいだったのですが、今回は40000枚と10倍以上増えました。

学習した結果は次のようになりました。
左が入力画像のグレースケール、真ん中が推論画像、一番右が正解のカラー画像です。

入力画像 推論画像 正解画像
f:id:nsr_9:20210908214036j:plain f:id:nsr_9:20210908214307j:plain f:id:nsr_9:20210908214043j:plain
f:id:nsr_9:20210908214314j:plain f:id:nsr_9:20210908214323j:plain f:id:nsr_9:20210908214330j:plain
f:id:nsr_9:20210908214339j:plain f:id:nsr_9:20210908214409j:plain f:id:nsr_9:20210908214530j:plain
f:id:nsr_9:20210908214347j:plain f:id:nsr_9:20210908214416j:plain f:id:nsr_9:20210908214437j:plain
f:id:nsr_9:20210908214353j:plain f:id:nsr_9:20210908214423j:plain f:id:nsr_9:20210908214443j:plain

所々「ん?」って感じの結果になっていますが、アスファルトや木など、直感的に色が想像しやすい対象物に関しては、思ったよりもちゃんと色復元されているように見えますね! 今回は、学習するデータセットドメイン(背景知識みたいな意味)が比較的にはっきりしているので、それっぽい学習がされているように思えます。

今回はPix2Pixで学習しましたが、次回はこの色塗りタスクをGANでやっていきたいと思います。