社会人研究者が色々頑張るブログ

pythonで画像処理やパターン認識をやっていきます

index画像からone-hotベクトル画像を作る方法

やりたいこと

Image Segmentationはpixel単位でカテゴリIDを予測するタスクです。
教師画像の形式はLoss関数の関係上、カテゴリIDをそのまま扱うのではなく、One-Hot-Vectorで扱います。

One-Hot-VectorはカテゴリIDを2進数の様に扱うデータ表現であり、Neural Networkで扱いやすい(学習しやすい)表現方法となります。

普通のデータ表現 One-hot Vector
f:id:nsr_9:20210902161118p:plain f:id:nsr_9:20210902161124p:plain

この例では、スカラー値のラベル表現ですが、One-hot-Vectorを画像で行うと次のようになります。

普通のデータ表現 One-Hot-Vector Image
f:id:nsr_9:20210902162919j:plain f:id:nsr_9:20210902163107p:plain

この様な感じで、pixel単位でカテゴリIDが記録されているグレースケール画像から、One-Hot-Vector画像(Category数, H, W)をPytorchで作成するコードを紹介します。

実際のコード

import torch
from torchvision.io import read_image
from torchvision import transforms
from torch.nn.functional import one_hot


# グレースケールでラベル画像を読み込む
label_img = read_image("label.png")[0]
# label_img.shape -> (H, W)

C = 10 # クラス数

# ラベル画像からOne-Hot-Vectorに変換
label_onehot = one_hot(label_img.long(), num_classes=C)

#channel_last(H, W, C)になってるので、Channel Firstに変換
label_onehot = torch.permute(label_onehot, (2, 0, 1))

pytorchにはone_hotというドンピシャの関数があり非常に助かりました。