社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

pytorchのお勉強(4):ネットワークの学習

はじめに

pytorchのお勉強の続きです。 今回は読み込んだデータセットを用いてモデルを学習していきます。
pytorchでCNNの学習システムを開発する際の全体像を以下に示します。

f:id:nsr_9:20210826195056p:plain

今回勉強する所はピンク色の矩形で囲んだ部分です。

学習(最適化)

今回は最適化に関するチュートリアルを参考にします。

また、学習に用いるNetworkモデルとDatasetの実装に関しては、以前の記事を参照してください。

nsr-9.hatenablog.jp nsr-9.hatenablog.jp

必要な部分だけ抜粋します。

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
import glob
from torch.utils.data import DataLoader


class CustomImageDataset(Dataset):
    def __init__(self):
        self.negative_path = glob.glob("./data/non-vehicles/**/**/*.png")
        self.positive_path = glob.glob("./data/vehicles/**/**/*.png")
            
    def __len__(self):
        return len(self.negative_path) + len(self.positive_path)

    def __getitem__(self, idx):
        if idx < len(self.negative_path):
            img_path = self.negative_path[idx]
            label = torch.Tensor([1, 0])
        else:
            idx = idx - len(self.negative_path)
            img_path = self.positive_path[idx]
            label = torch.Tensor([0, 1])
        image = read_image(img_path)
        sample = {"image": image, "label": label}
        return sample

NetworkとDatasetが定義できたら次に、学習に用いるLoss関数とOptimizerを定義します。

Loss関数には様々なものがありますが、回帰タスクにはMean Squared Error、分類タスクではCross Entropy Lossの使用が推奨されています。

今回は、車両とそれ以外の2クラス分類である為、Cross Entropyを使います。

loss_fn = nn.CrossEntropyLoss()

Optimizerも様々なものがあるのですが、だいたい多くの場合はAdamでうまく行くことが多いので(研究者がこんな事言って良いのかはおいといて)、今回もAdamを推奨パラメータで使います。

opt = torch.optim.Adam()

これらを踏まえてTrain Loopを実装します。

from dataset import CustomImageDataset
from network import Net
import torch
from torch.utils.data import DataLoader

def train_loop(dataloader, model, loss_fn, opt):
    data_size = len(dataloader.dataset)

    for batch, D in enumerate(dataloader):
        X = D["image"]
        y = D["label"]
        pred = model(X)
        loss = loss_fn(pred, y)

        # バックプロパゲーション
        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch*len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{data_size:>5d}]")


if __name__ == "__main__":
    loss_fn = torch.nn.CrossEntropyLoss()
    model = Net()
    optimizer = torch.optim.Adam(model.parameters())
    train_data = CustomImageDataset()
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    epochs = 10

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)

これを実行すると次のように学習が進んでいってることが確認できます。
f:id:nsr_9:20210826204754p:plain

せっかくなのでLoss値をグラフにしました。
f:id:nsr_9:20210826205944p:plain 急激に収束している事が確認できますね。

まとめ

今回はNetworkの学習方法について学びました。
このままだと学習しっぱなしでモデルが保存できないので、次回はモデルを保存して実際に推論する所を勉強していきたいと思います。