社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

PytorchによるGAN(1):全体像の把握

GAN

GAN(generative adversarial networks)[1]とは、2014年に登場したNeural Networkの学習スキームです。
一般的なNeural Networkは入力ベクトルとそれに対応する正解ベクトルの写像(対応)関係を回帰する学習を行うのですが、GANでは敵対的学習という画期的な学習方法を提案しています。
敵対的学習では、学習させたいNetwork(便宜上Generatorと呼びます)とDiscriminatorというGeneratorを学習させる為のNetworkの2つのNetworkを用います。

通常の学習 GAN
f:id:nsr_9:20210906174812p:plain f:id:nsr_9:20210906174819p:plain

Generatorは入力データに対して出力を返す、いわば普通のNetworkなのですが、DiscriminatorはGeneratorの出力値が満足の行くものなのか判定を下します。
Generatorは入力に対してどの様な出力を返せばよくわからないのですが、とりあえずDiscriminatorが満足する様な出力値を探索していき、それを繰り返していく内にだんだん””良い出力値””を推論できるようになっていきます。

例えば、以下の様な顔画像を生成するNetworkの学習を考えてみます。

f:id:nsr_9:20210906174923p:plain

画像生成Network(Generator)は、乱数入力xから顔っぽい画像(Fake画像)の生成を試みます。
Discriminatorは本物の顔画像(Real Data)と乱数より生成された生成された顔画像(Fake Data)を受け取り、真贋の判定を行います。
その際に、DiscriminatorはFake Dataに騙されないように学習を進めるわけですが、GeneratorはDiscriminatorを騙せるようによりRealっぽい画像を生成するように学習を進めます。
この様に、DiscriminatorとGeneratorを競わせるように学習させる事で、xと対をなす正解画像が無くても学習を進められます。

この技術が登場した時、非常にComputer Vision業界は盛り上がりました。
CVPR2016~2018では、体感で1/4くらいの研究でGANが扱われてた様な気がします。

今回から、この一世を風靡したGANをpytorchで実装し、遊んでみたいと思います。

全体像

今回は(今回も)pytorch公式のチュートリアル資料を参考にして進めていきます。 DCGAN Tutorial — PyTorch Tutorials 1.10.0+cu102 documentation

この資料ではGANから発展した技術であるDCGAN[2]を対象としています。
DCGANはDeep Convolutional Generative Adversarial Networksの略称で、畳み込みNeural NetworkによるGANとなっています。

DCGANのスキームを図にすると、次のようになります。
f:id:nsr_9:20210906195107p:plain

Generator部分はU-Net、Discriminator部分は普通の画像分類Networkで実装できそうな印象がありますね。

nsr-9.hatenablog.jp

nsr-9.hatenablog.jp

DiscriminatorのLoss関数はBinary Cross Entropyを設定し、OptimizerはDiscriminatorとGeneratorでそれぞれ異なるものを設定すれば良いみたいです。

criterion = nn.BCELoss()

# GとDの両方にオプティマイザ:Adamを設定する
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

pytorchはOptimizerをかなり柔軟に設定できるなぁと感じました。

DCGANの全体像はなんとなく理解することができました。
次回から、実際に実装を行っていきたいと思います。

[1] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).
[2] Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).