pytorchのお勉強(4):ネットワークの学習
はじめに
pytorchのお勉強の続きです。
今回は読み込んだデータセットを用いてモデルを学習していきます。
pytorchでCNNの学習システムを開発する際の全体像を以下に示します。
今回勉強する所はピンク色の矩形で囲んだ部分です。
学習(最適化)
今回は最適化に関するチュートリアルを参考にします。
また、学習に用いるNetworkモデルとDatasetの実装に関しては、以前の記事を参照してください。
nsr-9.hatenablog.jp nsr-9.hatenablog.jp
必要な部分だけ抜粋します。
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 1 input image channel, 6 output channels, 3x3 square convolution # kernel self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) # an affine operation: y = Wx + b self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 2) def forward(self, x): # Max pooling over a (2, 2) window x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # If the size is a square you can only specify a single number x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] # all dimensions except the batch dimension num_features = 1 for s in size: num_features *= s return num_features
import os import torch from torch.utils.data import Dataset from torchvision.io import read_image import glob from torch.utils.data import DataLoader class CustomImageDataset(Dataset): def __init__(self): self.negative_path = glob.glob("./data/non-vehicles/**/**/*.png") self.positive_path = glob.glob("./data/vehicles/**/**/*.png") def __len__(self): return len(self.negative_path) + len(self.positive_path) def __getitem__(self, idx): if idx < len(self.negative_path): img_path = self.negative_path[idx] label = torch.Tensor([1, 0]) else: idx = idx - len(self.negative_path) img_path = self.positive_path[idx] label = torch.Tensor([0, 1]) image = read_image(img_path) sample = {"image": image, "label": label} return sample
NetworkとDatasetが定義できたら次に、学習に用いるLoss関数とOptimizerを定義します。
Loss関数には様々なものがありますが、回帰タスクにはMean Squared Error、分類タスクではCross Entropy Lossの使用が推奨されています。
今回は、車両とそれ以外の2クラス分類である為、Cross Entropyを使います。
loss_fn = nn.CrossEntropyLoss()
Optimizerも様々なものがあるのですが、だいたい多くの場合はAdamでうまく行くことが多いので(研究者がこんな事言って良いのかはおいといて)、今回もAdamを推奨パラメータで使います。
opt = torch.optim.Adam()
これらを踏まえてTrain Loopを実装します。
from dataset import CustomImageDataset from network import Net import torch from torch.utils.data import DataLoader def train_loop(dataloader, model, loss_fn, opt): data_size = len(dataloader.dataset) for batch, D in enumerate(dataloader): X = D["image"] y = D["label"] pred = model(X) loss = loss_fn(pred, y) # バックプロパゲーション opt.zero_grad() loss.backward() opt.step() if batch % 100 == 0: loss, current = loss.item(), batch*len(X) print(f"loss: {loss:>7f} [{current:>5d}/{data_size:>5d}]") if __name__ == "__main__": loss_fn = torch.nn.CrossEntropyLoss() model = Net() optimizer = torch.optim.Adam(model.parameters()) train_data = CustomImageDataset() train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True) epochs = 10 for t in range(epochs): print(f"Epoch {t+1}\n-------------------------------") train_loop(train_dataloader, model, loss_fn, optimizer)
これを実行すると次のように学習が進んでいってることが確認できます。
せっかくなのでLoss値をグラフにしました。
急激に収束している事が確認できますね。
まとめ
今回はNetworkの学習方法について学びました。
このままだと学習しっぱなしでモデルが保存できないので、次回はモデルを保存して実際に推論する所を勉強していきたいと思います。