PytorchによるGAN(1):全体像の把握
GAN
GAN(generative adversarial networks)[1]とは、2014年に登場したNeural Networkの学習スキームです。
一般的なNeural Networkは入力ベクトルとそれに対応する正解ベクトルの写像(対応)関係を回帰する学習を行うのですが、GANでは敵対的学習という画期的な学習方法を提案しています。
敵対的学習では、学習させたいNetwork(便宜上Generatorと呼びます)とDiscriminatorというGeneratorを学習させる為のNetworkの2つのNetworkを用います。
通常の学習 | GAN |
---|---|
Generatorは入力データに対して出力を返す、いわば普通のNetworkなのですが、DiscriminatorはGeneratorの出力値が満足の行くものなのか判定を下します。
Generatorは入力に対してどの様な出力を返せばよくわからないのですが、とりあえずDiscriminatorが満足する様な出力値を探索していき、それを繰り返していく内にだんだん””良い出力値””を推論できるようになっていきます。
例えば、以下の様な顔画像を生成するNetworkの学習を考えてみます。
画像生成Network(Generator)は、乱数入力xから顔っぽい画像(Fake画像)の生成を試みます。
Discriminatorは本物の顔画像(Real Data)と乱数より生成された生成された顔画像(Fake Data)を受け取り、真贋の判定を行います。
その際に、DiscriminatorはFake Dataに騙されないように学習を進めるわけですが、GeneratorはDiscriminatorを騙せるようによりRealっぽい画像を生成するように学習を進めます。
この様に、DiscriminatorとGeneratorを競わせるように学習させる事で、xと対をなす正解画像が無くても学習を進められます。
この技術が登場した時、非常にComputer Vision業界は盛り上がりました。
CVPR2016~2018では、体感で1/4くらいの研究でGANが扱われてた様な気がします。
今回から、この一世を風靡したGANをpytorchで実装し、遊んでみたいと思います。
全体像
今回は(今回も)pytorch公式のチュートリアル資料を参考にして進めていきます。 DCGAN Tutorial — PyTorch Tutorials 1.10.0+cu102 documentation
この資料ではGANから発展した技術であるDCGAN[2]を対象としています。
DCGANはDeep Convolutional Generative Adversarial Networksの略称で、畳み込みNeural NetworkによるGANとなっています。
DCGANのスキームを図にすると、次のようになります。
Generator部分はU-Net、Discriminator部分は普通の画像分類Networkで実装できそうな印象がありますね。
DiscriminatorのLoss関数はBinary Cross Entropyを設定し、OptimizerはDiscriminatorとGeneratorでそれぞれ異なるものを設定すれば良いみたいです。
criterion = nn.BCELoss() # GとDの両方にオプティマイザ:Adamを設定する optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
pytorchはOptimizerをかなり柔軟に設定できるなぁと感じました。
DCGANの全体像はなんとなく理解することができました。
次回から、実際に実装を行っていきたいと思います。
[1] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems 27 (2014).
[2] Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).