pytorchのお勉強(5):モデルの保存と推論
はじめに
pytorchのお勉強の続きです。
今回は学習したモデルの保存と、そのモデルを読み込み実際に推論を行います。
学習モデルの保存と読み込み
今回はモデルの保存と読み込みに関するチュートリアルを参考にします。
また、学習に使うモデル等々は以前の記事を参照してください。
モデルを保存する際のフォーマットには,pytorch標準形式とONNX形式の二種類あります。
pytorchだけで学習、評価、推論を行う場合には標準のフォーマットで全く不都合が無いのですが、他のDeepLearningライブラリへモデルを流用したり、OpenCV等で活用する場合はONNX形式を用いると良いと思います。
pytorch標準形式を用いると、以下のコードで簡単にモデルの保存、読み込みができます。
torch.save(model, "model.pth") # モデルの保存 model = torch.load("model.pth") # モデルの読み込み
またONNX形式でモデルの保存する場合は以下の様に行います。
import torch.onnx as onnx input_image = torch.zeros((1, channel_size, height, width)) onnx.export(model, input_image, "model.onnx") # モデルの保存 # 読み込みは各種ライブラリのLoaderを用いる
モデルの読み込みと評価
モデルの保存、読み込みができるようになったので、いよいよ学習済みモデルを用いて実際に推論していきます。
今回のお勉強では、有志の方が公開している車両データセットを用いています。
このデータセットは64x64[pix]の車両背面画像(8500枚)とそれ以外の画像(8900枚)で構成されています。
適当に25枚ピックアップすると次のようになります。
※ライセンス表記がなかったので、問題がありそうならば下記の画像は削除します。
vehicle | non-vehicle |
---|---|
このデータセットの90%のデータを学習に、10%のデータを評価に使用しました。
モード | 学習枚数 | 評価枚数 |
---|---|---|
vehicle | 7650 | 850 |
non-vehicle | 8010 | 890 |
TrainとTestで異なるデータセット(のサブセット)を作成する為に、CustomDataset classを次のようにカスタマイズしました。
import os import torch from torch.utils.data import Dataset from torchvision.io import read_image import glob from torch.utils.data import DataLoader from torchvision import transforms as transforms class CustomImageDataset(Dataset): def __init__(self, train=True): self.negative_path = glob.glob("./data/non-vehicles/**/**/*.png") self.positive_path = glob.glob("./data/vehicles/**/**/*.png") self.resize = transforms.Resize(32) if train: # Train=Trueの時はデータセットの先頭90%を読み込む self.negative_path = self.negative_path[:int(len(self.negative_path)*0.9)] self.positive_path = self.positive_path[:int(len(self.positive_path)*0.9)] else: # Train=Falseの時はデータセットの後半10%を読み込む self.negative_path = self.negative_path[int(len(self.negative_path)*0.9):] self.positive_path = self.positive_path[int(len(self.positive_path)*0.9):] def __len__(self): return len(self.negative_path) + len(self.positive_path) def __getitem__(self, idx): if idx < len(self.negative_path): img_path = self.negative_path[idx] label = torch.tensor(0, dtype=torch.long) else: idx = idx - len(self.negative_path) img_path = self.positive_path[idx] label = torch.tensor(1, dtype=torch.long) image = read_image(img_path) image = self.resize(image) sample = {"image": image.float(), "label": label} return sample
モデルを評価するコードは次のようになりました。
import torch from torchvision.io import read_image import sys from torchvision import transforms as transforms from torch.utils.data import DataLoader from dataset import CustomImageDataset label_name = ["non-vehicle", "vehicle"] # ラベル名(使わなかったけど忘れないように) model = torch.load("./model.pth") # modelの読み込み if __name__ == "__main__": test_data = CustomImageDataset(train=False) test_dataloader = DataLoader(test_data, batch_size=1) data_iter = iter(test_dataloader) # 精度評価 tp, tn, fp, fn = 0, 0, 0, 0 for batch in data_iter: images = batch["image"] labels = batch["label"] out = model(images) y = torch.argmax(out, axis=-1) g = labels if g == 0: if y == 0: tn += 1 else: fn += 1 else: if y == 1: tp += 1 else: fp += 1 accuracy = (tp + tn)/(tp+fp+fn+tn) precision = tp / (tp+fp) recall = tp /(tp+fn) fmeature = (2*recall*precision)/(recall+precision) print("accuracy:\t",accuracy) print("precision:\t", precision) print("recall:\t\t", recall) print("F-measure:\t", fmeature)
精度評価部分では、評価データに対するTrue Positive, True Negative, False Positive, False Negativeをカウントし、AccuracyやRecallなど分類機の精度評価指標を求めています。
これらの詳細については、以下のサイトがとても詳しく解説してくれているので、ぜひご参照ください。
【入門者向け】機械学習の分類問題評価指標解説(正解率・適合率・再現率など) - Qiita
気になる精度は次のようになりました。
指標 | スコア |
---|---|
Accuracy | 0.993 |
Precision | 0.992 |
Recall | 0.994 |
F-measure | 0.993 |
各指標は1.0が最大値なので、このデータセットに対してはほとんど正しく分類できていると言えますね。
ここまで精度が良いと逆に間違ったデータが気になったので、以下にピックアップしてみました。
FN: 誤って車ではないと判断 | FP:誤って車と判断 |
---|---|
うーん、人間の目から見てもちょっとごちゃごちゃしてて判断が難しい感じがしますね。
False Negativeを見ていると、暗い背景の中に溶け込む真っ黒な車と、白飛びした背景に消えた白い車など、視覚的に認知が難しい対象はやはりCNNでも難しそうな気がしますね。
False Positiveを観察すると、左上の画像は左上に車両があるように見えます。
これはImageNetと呼ばれる大規模一般画像分類問題でもよく指摘されていた、ラベルミス問題に近しいものを感じますね。
アノテーション的には車両の背面でないため、non-vehicleが正しいのですが、これは車である!と判断されてもしょうがない様なデータですね。
まとめ
今回はモデルの保存、読み込みとデータの評価方法について勉強しました。
今回勉強した内容で、Deep Learningを扱う為に必要な最低限の技術を一通り扱えるようになりました。
pytorchは非常にチュートリアルが充実している上に、素直に実装できる大変優れたライブラリだと思いました。
次回からはpytorchも積極的に利用し、面白いものを作っていけたら良いなと思います。