2020年2月29日土曜日

より精度の高いfittingの為のLearning rateの設定について(その2)



前回の続きです。今回は、プログラムを実装します。
とはいっても大したプログラムではありませんし、泥臭く同じことの繰り返しですがお許しください。
前回述べたCIFAR10、MNIST、など6種類のデータセットについて確認しますが、CIFAR10についての実施の様子を下記に記録します。


%reload_ext autoreload
%autoreload 2
%matplotlib inline


from fastai.vision import *
from fastai.metrics import error_rate



Data Preparation
path = untar_data(URLs.CIFAR); path


tfms = get_transforms(do_flip=False)


data=ImageDataBunch.from_folder(path,'train','test',ds_tfms=tfms, size=224,bs=16)
data.normalize(imagenet_stats)


データサンプルの表示
data.show_batch(rows=3,figsize=(5,5))






















クラス名の表示

print(data.classes)


['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
クラスの数

len(data.classes)


10
クラス数は、data.cでも求めることができます。


import torchvision.models as models


Learning rate finder
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.lr_find()
learn1.recorder.plot()
learn1.recorder.plot_lr()

前回、fast.aiでのtraining手順として、

①learn.lr_finder()で学習率の目安を探索
②learn.recorder.plotで学習率をプロットし、最適な学習率を可視化する。
③learn.fit_one_cycle(●●,max_lr=○○)でtraining
④learn.unfreeze()でパラメーターの固定解除
⑤再度learn.lr_finder()で学習率の目安を探索
⑥learn.fit_one_cycle(●●,slice(◆◆,◇◇)でtraining

と書きましたが、ここでは③をとばしていきなり④のunfreeze()を行います。






































learn1.fit_one_cycle(5)
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5)
learn1.recorder.plot_lr()


epochtrain_lossvalid_lossaccuracy
00.7802160.5839920.808
10.5887310.4147290.8604
20.3527150.2340250.925
30.2038240.1508890.9482
40.1248590.1326760.9565
























学習率を指定しない場合の学習率の初期値は3e-3のようです。
また、あとで調べるpct_startについては0.30のようです。

learn1.fit_one_cycle(5,◆◆)を変化させて
●1e-1

learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-1)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
03.2954674.3463090.1173
13.9215724.9835870.1723
22.9598836.0407820.2201
31.766383.8369340.4246
41.3303341.1977810.5942
























●1e-2
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-2)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
01.7029641.5655720.4127
11.3887721.1417360.6113
21.0846910.9528960.7039
30.8142221.503640.7882
40.6791510.8662660.817
























●1e-3
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-3)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
01.0435670.8802540.7031
10.7579480.6194970.7985
20.5082530.3421010.8855
30.3231110.2302930.9217
40.225520.1939240.9358
























●1e-4
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-4)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.6373540.363550.886
10.4245170.2568430.9117
20.2918430.1801950.9392
30.1330930.1251770.9568
40.0682660.1083560.9646























●1e-5
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-5)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
01.2370.6282050.793
10.567210.2447850.9185
20.4431360.1866890.938
30.3524760.1639960.9453
40.3609520.1697880.9435























●1e-6
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,1e-6)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
03.0515612.2382220.2492
11.9772031.1724920.6117
21.5019980.8055710.7329
31.3161910.7259150.7575
41.3086560.6913760.7698























learn1.fit_one_cycle(5,◆◆,◇◇)を変化させて
●(1e-3,1e-2)

learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,slice(1e-3,1e-2))
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
01.2238290.9471840.6773
10.8926430.8206860.7632
20.6358280.5862260.8448
30.4095290.386310.8978
40.2825910.2914270.9132



































●(1e-4,1e-3)

learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,slice(1e-4,1e-3))
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.7318750.4670840.8481
10.4917960.3313050.8871
20.2998850.2074890.9278
30.1589140.1544690.9482
40.100620.1319410.9585























●(1e-5,1e-4)
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,slice(1e-5,1e-4))
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.6035840.29280.9025
10.4291660.1914360.9351
20.2124770.1402530.9519
30.1180280.1151150.9613
40.1039560.1091390.9646

























●(1e-6,1e-5)
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5,slice(1e-6,1e-5))
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
01.8383871.0902590.6381
10.8511550.3943550.8689
20.6691680.3060720.9017
30.5216870.2792490.9091
40.5634960.2701890.913

























pct_startを変化させて
●pct_start=0.05

learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5, slice(1e-4,1e-3), pct_start=0.05)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.5802750.4829660.8361
10.3911750.2634710.9113
20.2155970.1926170.9355
30.1407550.1373280.957
40.0897470.1328450.9589
























●pct_start=0.10
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5, slice(1e-4,1e-3), pct_start=0.10)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.6041410.4270040.8577
10.374160.2614630.9141
20.2323740.166950.9448
30.1430510.1413180.9537
40.0893410.1328350.9567
























●pct_start=0.30
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5, slice(1e-4,1e-3), pct_start=0.30)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.7504670.421110.852
10.4989620.3520080.8837
20.3332380.22740.9265
30.1581430.1453930.9523
40.1044810.1314250.9571

























● pct_start=0.50
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5, slice(1e-4,1e-3), pct_start=0.50)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.6594090.338610.89
10.5768970.3964110.8642
20.452040.2958520.8992
30.2137370.1729170.9407
40.1160440.1419520.954
























●pct_start=0.75
learn1 = cnn_learner(data, models.resnet34, metrics=accuracy, callback_fns=ShowGraph)
learn1.unfreeze()
learn1.fit_one_cycle(5, slice(1e-4,1e-3), pct_start=0.75)
learn1.recorder.plot_lr()

epochtrain_lossvalid_lossaccuracy
00.50360.2600670.9109
10.5331160.3432040.891
20.5411150.336610.8864
30.3867340.2473020.9149
40.1809720.1383090.9537
























結果

上記にepoch、またそのグラフを示していますので大体の様子はお解りと思います。
これは、cifar10データセットの結果です。
その他のデータセット(imagenette2、MNIST・・・)の結果もまとめてグラフにしました。

まず、learn1.fit_one_cycle(5,◆◆)を変化させた場合、いずれのデータセットも学習率としては1e-4乃至は1e-5が最適のようです。
また、MNISTは学習率に対してロバストネス(robustness)と言えるのでしょうか。


1.00E-011.00E-021.00E-031.00E-041.00E-051.00E-06
imagenette20.2667520.5977070.9149040.9785990.9844590.81707
MNIST0.99480.99570.99710.99690.99510.9837
dogs-cats0.5070420.6291080.8262910.9812210.9483570.657277
dogscats0.6390.7180.96850.9860.99350.982
cifar100.59420.8170.93580.96460.94350.7698
PET0.0460080.1143440.8403250.9242220.8714480.121786
insect0.2040.28160.48160.89040.85680.3192
























つぎに、learn1.fit_one_cycle(5,◆◆,◇◇)を変化させた場合、学習率としては(1e-5,1e-4)が最適のようです。

1e-3,1e-21e-4,1e-31e-5,1e-41e-6,1e-5
imagenette20.896560.9661150.9892990.974777
MNIST0.99530.99660.99690.9941
dogs-cats0.582160.9812210.9765260.920188
dogscats0.96050.98650.99050.992
cifar100.91320.95850.96460.913
PET0.7395130.9120430.9289580.702977
insect0.3920.8680.90240.7376






















さいごに、pct_startを変化させた場合、いづれのデータセットもあまり大きな変化は認められない結果となりました。
これが偶然なのかどうかは分かりません。pct_startを変化させるときの学習率(1e-4,1e-3)が適当な値だったため、pct_startの効果が現れにくかったのか、違う学習率だったらもっと差がでたのか不明です。

もっとも今回のような条件設定での実験のやり方は、あまり薦められるものではなく、詳しくは実験計画法などを用いて検討をするべきでしょう。簡便にやりすぎたのかも知れません。

0.050.100.300.500.75
imagenette20.9694270.9686620.9701910.9633120.963057
MNIST0.99630.99690.99730.99640.997
dogs-cats0.9718310.990610.9765260.9718310.967136
dogscats0.98550.98550.9850.98750.9855
cifar100.95890.95670.95710.9540.9537
PET0.9194860.9248990.9073070.914750.897158
insect0.88160.88320.89120.88080.868























ここで、6種類のデータセットのLearning rate finderでの結果を見てみます。
CIFAR10























MNIST























Oxford-IIIT-Pet























imagenette2
























dogscats
























dogs-cats























まとめ

●以外だったのは(と言うか知識不足か偶然か)いづれのデータセットもLearning rate finderでのボトムは、ほぼ1e-3~1e-2のあたりにある
最適な学習率はほぼ1e-5~1e-4あたり
●もちろん1e-5~1e-4あたりで、3e-5にするとかpct_startを変えるとかの微調整をすれば、若干改善する余地がある。

まだまだ私の経験不足で今後どのように変化するか分かりませんが、今のところ学習率に関しては上記のような結論です。