2019年12月25日水曜日

Transfer learningで更新されるパラメーターの確認(その続き)

fast.aiを使うとこんなに少ないコードでタスクを達成できることを見てきましたが、前回2019/12/20の投稿でTransfer Learningでパラメーター更新について少し勉強ました。
ここでは、fastaiでのTransfer Learningで、learn.freeze()とlearn.unfreeze()とでの結果の違いについて確認しておきます。
前回も書きましたが、fast.aiでは、2段階でトレーニングの精度を上げているようです。

1.まずは、そのままlearn.fit_one_cycle(4)でトレーニング
2.その後learn.unfreeze()し、learn.fit_one_cycle(4)でトレーニング



%reload_ext autoreload
%autoreload 2
%matplotlib inline


from fastai.vision import *
from fastai.metrics import error_rate,accuracy


dir='F:/DataSet'


tfms = get_transforms(do_flip=False)


path = untar_data(URLs.MNIST,dest=dir); path
path = dir+'/cifar10/cifar10_png'; path
data=ImageDataBunch.from_folder(path,'train','val',ds_tfms=tfms, size=224,bs=16)
data.normalize(imagenet_stats)


learn1 = cnn_learner(data, models.resnet34, metrics=accuracy,callback_fns=ShowGraph)
learn2 = cnn_learner(data, models.resnet34, metrics=accuracy,callback_fns=ShowGraph)

  • learn1では、一度デフォルトのlearn.freeze()でトレーニングし、その後unfreeze()して再度トレーニング
  • learn2では、初めからunfreeze()してトレーニング

の2つ実行してみます。
learn1.lr_find()


learn1.recorder.plot()



1.learn.freeze()でトレーニング


learn1.fit_one_cycle(4)

epoch     train_loss     valid_loss     accuracy     time
0            0.536407     0.295173     0.899200     07:56
1            0.350948     0.191332     0.936000     07:53
2            0.330007     0.162405     0.945200     07:56
3            0.310012     0.159538     0.943000     07:55




learn1.unfreeze()
さらにmodelをunfreeze()してtrainingを続ける

learn1.lr_find()
learn1.recorder.plot()


Loss-Learning Rateのグラフの様子が、learn1.freeze()のときと異なっています。

このこともfast.aiのv3レクチャーで触れられていたと思います。
また、本来ならこの結果を踏まえてLearning Rateを設定して以下のトレーニングを継続するのでしょうがそれはまた改めて


learn1.fit_one_cycle(4)

epoch    train_loss     valid_loss     accuracy      time
0          0.675299     0.545486     0.815900     09:27
1          0.473109     0.297044     0.899300     09:26
2          0.257993     0.172472     0.940100     09:31
3          0.121066     0.137457     0.954700     09:28




2.learn.unfreeze()でトレーニング

learn2.unfreeze()


learn2.lr_find()
learn2.recorder.plot()



learn2.fit_one_cycle(4)

epoch    train_loss   valid_loss   accuracy     time
0           0.775636   0.665999   0.778800    09:29
1           0.512648   0.335976   0.884100    09:24
2           0.270773   0.183393   0.936500    09:29
3           0.158724   0.141045   0.953500    09:27





  • unfreeze()することで、若干accuracyが改善されました。
  • freeeze()→unfreeze()と2段階でトレーニングしても、初めからunfreeze()してトレーニングしても、今回はあまり変化はありませんでした。
  • freeeze()→unfreeze()と2段階でトレーニングの際に、学習率を設定して再トレーニングすれば違った結果が得られたのかも知れませんが、それに関してはまた改めて・・・