今回は、fast.aiライブラリーを利用して行ってみたいと思います。
fast.aiのインストールなど、詳しくはここを御覧ください。
PyTorchの「data_transforms」の部分を担っているのでしょう。詳しくわ改めて書きます。
imagenet_statsは、ImageNetに合わせる正規化、PyTorchで「transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])」の部分を担っているのだと思います。
【注意】
create_cnn
is deprecated and is now named cnn_learner
fast.aiでは、トレーニングに際して、fit_one_cycleを推奨しているようです。
fit_one_cycleについても、改めて検証したいと思います。
epoch train_loss valid_loss error_rate time
0 0.516307 0.283570 0.097300 07:50
1 0.393534 0.194867 0.066300 07:48
2 0.351924 0.166043 0.057800 07:49
3 0.280021 0.160570 0.055000 07:48
- たったこれだけのコードで、前回PyTorchで書いたTransfer Learningとほぼ同等な結果が得られました。
- 自分で色々とカスタマイズしたい場合は、fast.aiではブラックボックス化されているので、PyTorchで書くしかないでしょう。fast.aiでも出来るのかも知れませんが勉強不足でそこまで分かりません。
- ただ、トレーニング時間が31分と、PyTorchの23分よりかなり多くかかっています。
0 件のコメント:
コメントを投稿