2020年2月12日水曜日

より精度の高いfittingの為のLearning rateの設定について(その1)

2019年11月13日の投稿で、PyTorchと比較してfast.aiでcifar10のTransfer Learningを行い、如何にfast.aiが少ないコードで実行可能かを紹介しました。

その中で、fast.aiではfittingを行うのに.fit_one_cycle()を推奨しており詳しくは改めて・・・としておりました

今までの私が理解した範囲では、Learning rate(学習率)がfittingの精度を大きく左右するということです。


ここで、fast.aiでのtraining手順をまとめると、

①learn.lr_finder()で学習率の目安を探索
②learn.recorder.plotで学習率をプロットし、最適な学習率を可視化する。

これは、learn.recorder.plotの例です。ネットを色々と見てみると、一番小さいLossのLearning rateの10分の1の値とか、一番小さいLossのLearning rateに至る傾きの急なところを採用とかあります。まあ3e-4あたりでしょうか。
③learn.fit_one_cycle(●●,max_lr=○○)でtraining
 ●●はepochs数、○○は②て求めた学習率。max_lr以外にもslice( )もあります。
④learn.unfreeze()でパラメーターの固定解除
⑤再度learn.lr_finder()で学習率の目安を探索
⑥learn.fit_one_cycle(●●,slice(◆◆,◇◇)でtraining
 ◆◆は⑤で選んだ学習率。選び方は②と同じ一番小さいLossのLearning rateの10分の1の値。
 ◇◇は②で求めたパラメータ固定時の学習率の5〜10分の1の値

大体このような流れではないでしょうか。

ここで実際にやってみてlearn.recorder.plotで学習率をプロットして、学習率として幾らの値を採用すればよいのか迷い、不安に思うところです。

そこで学習率を意図的に色々と変えて試してみました。また学習率だけでなくlearn.fit_one_cycleに出てくるパラメータのpct_startも変化させて調べてみました。



試験に用いたサンプルデータ
試験には下記のデータを用いました。

A.cifar10
おなじみのデータセットです。
Train: LabelList (50000 items)
Valid: LabelList (10000 items)
クラス数: 2
クラス名
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

B.MNIST
これもおなじみです。
Train: LabelList (60000 items)
Valid: LabelList (10000 items)
クラス数: 2
クラス名: 
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

C.Oxford-IIIT-Pet
これもよく見かけるデータセットです。
Train: LabelList (5912 items)
Valid: LabelList (1478 items)
クラス数: 37
クラス名: 
['Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair', 'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue', 'Siamese', 'Sphynx', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher', 'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']

D.imagenette2
有名なImageNetデータセットの中の10クラス
Train: LabelList (9469 items)
Valid: LabelList (3925 items)
クラス数: 10
クラス名: 
['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']

E.dogscats
Train: LabelList (23000 items)
Valid: LabelList (2000 items)
クラス数: 2
クラス名: 
['cats', 'dogs']

F.dogs-cats
Eのサブセット版か?忘れてしまいました。
Train: LabelList (856 items)
Valid: LabelList (213 items)
クラス数: 2
クラス名: 
['cats', 'dogs']

G.insects
私がネットから集めた害虫の画像データ
Train: LabelList (5000 items)
Valid: LabelList (1250 items)
クラス数: 8
クラス名:
['ants', 'bees', 'centipede', 'cockroach', 'fly', 'mosquito', 'spider', 'stink']


これらの8種類のデータを使いました。
A〜Dは、fast.aiのdatasetにもあります。


今日はここまで、次回試験に使ったプログラムを紹介します。




0 件のコメント:

コメントを投稿