index画像からone-hotベクトル画像を作る方法
やりたいこと
Image Segmentationはpixel単位でカテゴリIDを予測するタスクです。
教師画像の形式はLoss関数の関係上、カテゴリIDをそのまま扱うのではなく、One-Hot-Vectorで扱います。
One-Hot-VectorはカテゴリIDを2進数の様に扱うデータ表現であり、Neural Networkで扱いやすい(学習しやすい)表現方法となります。
普通のデータ表現 | One-hot Vector |
---|---|
この例では、スカラー値のラベル表現ですが、One-hot-Vectorを画像で行うと次のようになります。
普通のデータ表現 | One-Hot-Vector Image |
---|---|
この様な感じで、pixel単位でカテゴリIDが記録されているグレースケール画像から、One-Hot-Vector画像(Category数, H, W)をPytorchで作成するコードを紹介します。
実際のコード
import torch from torchvision.io import read_image from torchvision import transforms from torch.nn.functional import one_hot # グレースケールでラベル画像を読み込む label_img = read_image("label.png")[0] # label_img.shape -> (H, W) C = 10 # クラス数 # ラベル画像からOne-Hot-Vectorに変換 label_onehot = one_hot(label_img.long(), num_classes=C) #channel_last(H, W, C)になってるので、Channel Firstに変換 label_onehot = torch.permute(label_onehot, (2, 0, 1))
pytorchにはone_hotというドンピシャの関数があり非常に助かりました。