その中で、fast.aiではfittingを行うのに.fit_one_cycle()を推奨しており詳しくは改めて・・・としておりました。
今までの私が理解した範囲では、Learning rate(学習率)がfittingの精度を大きく左右するということです。
ここで、fast.aiでのtraining手順をまとめると、
①learn.lr_finder()で学習率の目安を探索
②learn.recorder.plotで学習率をプロットし、最適な学習率を可視化する。
③learn.fit_one_cycle(●●,max_lr=○○)でtraining
●●はepochs数、○○は②て求めた学習率。max_lr以外にもslice( )もあります。
④learn.unfreeze()でパラメーターの固定解除
⑤再度learn.lr_finder()で学習率の目安を探索
⑥learn.fit_one_cycle(●●,slice(◆◆,◇◇)でtraining
◆◆は⑤で選んだ学習率。選び方は②と同じ一番小さいLossのLearning rateの10分の1の値。
◇◇は②で求めたパラメータ固定時の学習率の5〜10分の1の値
大体このような流れではないでしょうか。
ここで実際にやってみてlearn.recorder.plotで学習率をプロットして、学習率として幾らの値を採用すればよいのか迷い、不安に思うところです。
そこで学習率を意図的に色々と変えて試してみました。また学習率だけでなくlearn.fit_one_cycleに出てくるパラメータのpct_startも変化させて調べてみました。
試験に用いたサンプルデータ
試験には下記のデータを用いました。
A.cifar10
おなじみのデータセットです。
Train: LabelList (50000 items)
Valid: LabelList (10000 items)
クラス数: 2
クラス名
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
B.MNIST
これもおなじみです。
Train: LabelList (60000 items)
Valid: LabelList (10000 items)
クラス数: 2
クラス名:
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
C.Oxford-IIIT-Pet
これもよく見かけるデータセットです。
Train: LabelList (5912 items)
Valid: LabelList (1478 items)
クラス数: 37
クラス名:
['Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair', 'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue', 'Siamese', 'Sphynx', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher', 'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']
D.imagenette2
有名なImageNetデータセットの中の10クラス
Train: LabelList (9469 items)
Valid: LabelList (3925 items)
クラス数: 10
クラス名:
['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']
E.dogscats
Train: LabelList (23000 items)
Valid: LabelList (2000 items)
クラス数: 2
クラス名:
['cats', 'dogs']
F.dogs-cats
Eのサブセット版か?忘れてしまいました。
Train: LabelList (856 items)
Valid: LabelList (213 items)
クラス数: 2
クラス名:
['cats', 'dogs']
G.insects
私がネットから集めた害虫の画像データ
Train: LabelList (5000 items)
Valid: LabelList (1250 items)
クラス数: 8
クラス名:
['ants', 'bees', 'centipede', 'cockroach', 'fly', 'mosquito', 'spider', 'stink']
これらの8種類のデータを使いました。
A〜Dは、fast.aiのdatasetにもあります。
今日はここまで、次回試験に使ったプログラムを紹介します。
0 件のコメント:
コメントを投稿