2019年11月13日水曜日

fastaiでcifar10のTransfer Learningを試す

2019年11月11日の投稿で、PyTorchでcifar10のTransfer Learningを行いました。
今回は、fast.aiライブラリーを利用して行ってみたいと思います。

fast.aiのインストールなど、詳しくはここを御覧ください。



%reload_ext autoreload
%autoreload 2
%matplotlib inline



from fastai.vision import *
from fastai.metrics import error_rate




# root dirの取得
import os

if os.name == 'posix':
    root_dir='/media/{ユーザー名}/DataScience'
elif os.name == 'nt':
    root_dir = 'F:'
else:
    pass 

dataset_dir=root_dir+'/DataSet/cifar10/cifar10_png'




tfms = get_transforms(do_flip=False)

PyTorchの「data_transforms」の部分を担っているのでしょう。詳しくわ改めて書きます。


path = dataset_dir
data=ImageDataBunch.from_folder(path,'train','val',ds_tfms=tfms, size=224,bs=16)
data.normalize(imagenet_stats)


imagenet_statsは、ImageNetに合わせる正規化、PyTorchで「transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])」の部分を担っているのだと思います。



learn = cnn_learner(data, models.resnet34, metrics=error_rate,callback_fns=ShowGraph)


【注意】create_cnn is deprecated and is now named cnn_learner



learn.lr_find()


learn.recorder.plot()


今日の場合、上2つは必要ないでしょうが、とりあえず書いておきます。
fast.aiでは、トレーニングに際して、fit_one_cycleを推奨しているようです。
fit_one_cycleについても、改めて検証したいと思います。



learn.fit_one_cycle(4)


epoch train_loss  valid_loss  error_rate  time
0   0.516307  0.283570  0.097300  07:50
1   0.393534  0.194867  0.066300  07:48
2   0.351924  0.166043  0.057800  07:49
3   0.280021  0.160570  0.055000  07:48




  • たったこれだけのコードで、前回PyTorchで書いたTransfer Learningとほぼ同等な結果が得られました。
  • 自分で色々とカスタマイズしたい場合は、fast.aiではブラックボックス化されているので、PyTorchで書くしかないでしょう。fast.aiでも出来るのかも知れませんが勉強不足でそこまで分かりません。
  • ただ、トレーニング時間が31分と、PyTorchの23分よりかなり多くかかっています。

0 件のコメント:

コメントを投稿