Friday, December 20, 2019

Checking Which Parameters Are Updated in Transfer Learning


In the previous post we looked at the effect of the pretrained weights in transfer learning.
The conclusion was that pretrained=True is essentially mandatory: without it, transfer learning loses its point in the first place.

fast.ai seems to raise training accuracy in two stages:
1. First, train as-is with learn.fit_one_cycle(4).

2. Then call learn.unfreeze() and train with learn.fit_one_cycle(4) again.

Between steps 1 and 2 you would normally re-tune the learning rate, which is the real point of learn.fit_one_cycle(), but I will save that for another time.
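Put together, the two-stage recipe looks like this (a sketch consolidating the steps above; the learning-rate re-tuning is omitted, as just noted):

# Stage 1: the pretrained body is frozen by default, so only some
# of the parameters (listed below) are updated
learn.fit_one_cycle(4)

# Stage 2: unlock all parameters and fine-tune the whole network
learn.unfreeze()
learn.fit_one_cycle(4)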

What does unfreeze() in step 2 actually do? It removes the lock on the parameters so that all of them can be retrained (updated). Put the other way around, in step 1 some parameters are locked and cannot be updated.
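Under the hood, that "lock" is simply each parameter's requires_grad flag. Roughly (a simplified sketch of what fastai does, not its exact implementation):

# roughly what learn.unfreeze() amounts to: clear the lock on every
# parameter so the optimizer will update all of them again
for param in learn.model.parameters():
    param.requires_grad = True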

For now, let's just confirm which parameters are actually updated in step 1 and in step 2.



from fastai.vision import *
from fastai.metrics import error_rate





dataset_dir = 'F:/DataSet/cifar10/cifar10_png'  # local folder with CIFAR-10 as PNG images





tfms = get_transforms(do_flip=False)  # standard fastai augmentations, minus horizontal flips





path = dataset_dir
# build a DataBunch from the train/val folders, resized to 224 px, batch size 16
data = ImageDataBunch.from_folder(path, 'train', 'val', ds_tfms=tfms, size=224, bs=16)
data.normalize(imagenet_stats)  # normalize with ImageNet stats to match the pretrained weights





# ResNet-34 pretrained on ImageNet; cnn_learner creates the learner frozen by default
learn = cnn_learner(data, models.resnet34, metrics=error_rate, callback_fns=ShowGraph)





# list every parameter the optimizer will currently update (requires_grad is True)
for name, param in learn.model.named_parameters():
    if param.requires_grad:
        print("\t", name)



★ Parameters updated in the default learn.freeze() state
0.1.weight
0.1.bias
0.4.0.bn1.weight
0.4.0.bn1.bias
0.4.0.bn2.weight
0.4.0.bn2.bias
0.4.1.bn1.weight
0.4.1.bn1.bias
0.4.1.bn2.weight
0.4.1.bn2.bias
0.4.2.bn1.weight
0.4.2.bn1.bias
0.4.2.bn2.weight
0.4.2.bn2.bias
0.5.0.bn1.weight
0.5.0.bn1.bias
0.5.0.bn2.weight
0.5.0.bn2.bias
0.5.0.downsample.1.weight
0.5.0.downsample.1.bias
0.5.1.bn1.weight
0.5.1.bn1.bias
0.5.1.bn2.weight
0.5.1.bn2.bias
0.5.2.bn1.weight
0.5.2.bn1.bias
0.5.2.bn2.weight
0.5.2.bn2.bias
0.5.3.bn1.weight
0.5.3.bn1.bias
0.5.3.bn2.weight
0.5.3.bn2.bias
0.6.0.bn1.weight
0.6.0.bn1.bias
0.6.0.bn2.weight
0.6.0.bn2.bias
0.6.0.downsample.1.weight
0.6.0.downsample.1.bias
0.6.1.bn1.weight
0.6.1.bn1.bias
0.6.1.bn2.weight
0.6.1.bn2.bias
0.6.2.bn1.weight
0.6.2.bn1.bias
0.6.2.bn2.weight
0.6.2.bn2.bias
0.6.3.bn1.weight
0.6.3.bn1.bias
0.6.3.bn2.weight
0.6.3.bn2.bias
0.6.4.bn1.weight
0.6.4.bn1.bias
0.6.4.bn2.weight
0.6.4.bn2.bias
0.6.5.bn1.weight
0.6.5.bn1.bias
0.6.5.bn2.weight
0.6.5.bn2.bias
0.7.0.bn1.weight
0.7.0.bn1.bias
0.7.0.bn2.weight
0.7.0.bn2.bias
0.7.0.downsample.1.weight
0.7.0.downsample.1.bias
0.7.1.bn1.weight
0.7.1.bn1.bias
0.7.1.bn2.weight
0.7.1.bn2.bias
0.7.2.bn1.weight
0.7.2.bn1.bias
0.7.2.bn2.weight
0.7.2.bn2.bias
1.2.weight
1.2.bias
1.4.weight
1.4.bias
1.6.weight
1.6.bias
1.8.weight
1.8.bias
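
Note that in the frozen state the body (the 0.* group) contributes only BatchNorm weights and biases (downsample.1 is also a BatchNorm layer); everything else belongs to the newly added head (the 1.* group). This reflects fastai's default of keeping BatchNorm trainable even while frozen (train_bn=True). Counting them is a one-liner (a quick sanity-check sketch):

# count the parameter tensors the optimizer will currently update
print(sum(1 for p in learn.model.parameters() if p.requires_grad))  # -> 80 while frozen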





learn.unfreeze()





# check again after unfreezing: now every parameter should appear
for name, param in learn.model.named_parameters():
    if param.requires_grad:
        print("\t", name)



★ Parameters updated after learn.unfreeze()
0.0.weight
0.1.weight
0.1.bias
0.4.0.conv1.weight
0.4.0.bn1.weight
0.4.0.bn1.bias
0.4.0.conv2.weight
0.4.0.bn2.weight
0.4.0.bn2.bias
0.4.1.conv1.weight
0.4.1.bn1.weight
0.4.1.bn1.bias
0.4.1.conv2.weight
0.4.1.bn2.weight
0.4.1.bn2.bias
0.4.2.conv1.weight
0.4.2.bn1.weight
0.4.2.bn1.bias
0.4.2.conv2.weight
0.4.2.bn2.weight
0.4.2.bn2.bias
0.5.0.conv1.weight
0.5.0.bn1.weight
0.5.0.bn1.bias
0.5.0.conv2.weight
0.5.0.bn2.weight
0.5.0.bn2.bias
0.5.0.downsample.0.weight
0.5.0.downsample.1.weight
0.5.0.downsample.1.bias
0.5.1.conv1.weight
0.5.1.bn1.weight
0.5.1.bn1.bias
0.5.1.conv2.weight
0.5.1.bn2.weight
0.5.1.bn2.bias
0.5.2.conv1.weight
0.5.2.bn1.weight
0.5.2.bn1.bias
0.5.2.conv2.weight
0.5.2.bn2.weight
0.5.2.bn2.bias
0.5.3.conv1.weight
0.5.3.bn1.weight
0.5.3.bn1.bias
0.5.3.conv2.weight
0.5.3.bn2.weight
0.5.3.bn2.bias
0.6.0.conv1.weight
0.6.0.bn1.weight
0.6.0.bn1.bias
0.6.0.conv2.weight
0.6.0.bn2.weight
0.6.0.bn2.bias
0.6.0.downsample.0.weight
0.6.0.downsample.1.weight
0.6.0.downsample.1.bias
0.6.1.conv1.weight
0.6.1.bn1.weight
0.6.1.bn1.bias
0.6.1.conv2.weight
0.6.1.bn2.weight
0.6.1.bn2.bias
0.6.2.conv1.weight
0.6.2.bn1.weight
0.6.2.bn1.bias
0.6.2.conv2.weight
0.6.2.bn2.weight
0.6.2.bn2.bias
0.6.3.conv1.weight
0.6.3.bn1.weight
0.6.3.bn1.bias
0.6.3.conv2.weight
0.6.3.bn2.weight
0.6.3.bn2.bias
0.6.4.conv1.weight
0.6.4.bn1.weight
0.6.4.bn1.bias
0.6.4.conv2.weight
0.6.4.bn2.weight
0.6.4.bn2.bias
0.6.5.conv1.weight
0.6.5.bn1.weight
0.6.5.bn1.bias
0.6.5.conv2.weight
0.6.5.bn2.weight
0.6.5.bn2.bias
0.7.0.conv1.weight
0.7.0.bn1.weight
0.7.0.bn1.bias
0.7.0.conv2.weight
0.7.0.bn2.weight
0.7.0.bn2.bias
0.7.0.downsample.0.weight
0.7.0.downsample.1.weight
0.7.0.downsample.1.bias
0.7.1.conv1.weight
0.7.1.bn1.weight
0.7.1.bn1.bias
0.7.1.conv2.weight
0.7.1.bn2.weight
0.7.1.bn2.bias
0.7.2.conv1.weight
0.7.2.bn1.weight
0.7.2.bn1.bias
0.7.2.conv2.weight
0.7.2.bn2.weight
0.7.2.bn2.bias
1.2.weight
1.2.bias
1.4.weight
1.4.bias
1.6.weight
1.6.bias
1.8.weight
1.8.bias



Since the difference between the two is hard to see at a glance, I lined them up side by side.




Freeze(80)                  Unfreeze(116)
-                           0.0.weight
0.1.weight                  0.1.weight
0.1.bias                    0.1.bias

-                           0.4.0.conv1.weight
0.4.0.bn1.weight            0.4.0.bn1.weight
0.4.0.bn1.bias              0.4.0.bn1.bias
-                           0.4.0.conv2.weight
0.4.0.bn2.weight            0.4.0.bn2.weight
0.4.0.bn2.bias              0.4.0.bn2.bias
-                           0.4.1.conv1.weight
0.4.1.bn1.weight            0.4.1.bn1.weight
0.4.1.bn1.bias              0.4.1.bn1.bias
-                           0.4.1.conv2.weight
0.4.1.bn2.weight            0.4.1.bn2.weight
0.4.1.bn2.bias              0.4.1.bn2.bias
-                           0.4.2.conv1.weight
0.4.2.bn1.weight            0.4.2.bn1.weight
0.4.2.bn1.bias              0.4.2.bn1.bias
-                           0.4.2.conv2.weight
0.4.2.bn2.weight            0.4.2.bn2.weight
0.4.2.bn2.bias              0.4.2.bn2.bias

-                           0.5.0.conv1.weight
0.5.0.bn1.weight            0.5.0.bn1.weight
0.5.0.bn1.bias              0.5.0.bn1.bias
-                           0.5.0.conv2.weight
0.5.0.bn2.weight            0.5.0.bn2.weight
0.5.0.bn2.bias              0.5.0.bn2.bias
-                           0.5.0.downsample.0.weight
0.5.0.downsample.1.weight   0.5.0.downsample.1.weight
0.5.0.downsample.1.bias     0.5.0.downsample.1.bias
-                           0.5.1.conv1.weight
0.5.1.bn1.weight            0.5.1.bn1.weight
0.5.1.bn1.bias              0.5.1.bn1.bias
-                           0.5.1.conv2.weight
0.5.1.bn2.weight            0.5.1.bn2.weight
0.5.1.bn2.bias              0.5.1.bn2.bias
-                           0.5.2.conv1.weight
0.5.2.bn1.weight            0.5.2.bn1.weight
0.5.2.bn1.bias              0.5.2.bn1.bias
-                           0.5.2.conv2.weight
0.5.2.bn2.weight            0.5.2.bn2.weight
0.5.2.bn2.bias              0.5.2.bn2.bias
-                           0.5.3.conv1.weight
0.5.3.bn1.weight            0.5.3.bn1.weight
0.5.3.bn1.bias              0.5.3.bn1.bias
-                           0.5.3.conv2.weight
0.5.3.bn2.weight            0.5.3.bn2.weight
0.5.3.bn2.bias              0.5.3.bn2.bias

-                           0.6.0.conv1.weight
0.6.0.bn1.weight            0.6.0.bn1.weight
0.6.0.bn1.bias              0.6.0.bn1.bias
-                           0.6.0.conv2.weight
0.6.0.bn2.weight            0.6.0.bn2.weight
0.6.0.bn2.bias              0.6.0.bn2.bias
-                           0.6.0.downsample.0.weight
0.6.0.downsample.1.weight   0.6.0.downsample.1.weight
0.6.0.downsample.1.bias     0.6.0.downsample.1.bias
-                           0.6.1.conv1.weight
0.6.1.bn1.weight            0.6.1.bn1.weight
0.6.1.bn1.bias              0.6.1.bn1.bias
-                           0.6.1.conv2.weight
0.6.1.bn2.weight            0.6.1.bn2.weight
0.6.1.bn2.bias              0.6.1.bn2.bias
-                           0.6.2.conv1.weight
0.6.2.bn1.weight            0.6.2.bn1.weight
0.6.2.bn1.bias              0.6.2.bn1.bias
-                           0.6.2.conv2.weight
0.6.2.bn2.weight            0.6.2.bn2.weight
0.6.2.bn2.bias              0.6.2.bn2.bias
-                           0.6.3.conv1.weight
0.6.3.bn1.weight            0.6.3.bn1.weight
0.6.3.bn1.bias              0.6.3.bn1.bias
-                           0.6.3.conv2.weight
0.6.3.bn2.weight            0.6.3.bn2.weight
0.6.3.bn2.bias              0.6.3.bn2.bias
-                           0.6.4.conv1.weight
0.6.4.bn1.weight            0.6.4.bn1.weight
0.6.4.bn1.bias              0.6.4.bn1.bias
-                           0.6.4.conv2.weight
0.6.4.bn2.weight            0.6.4.bn2.weight
0.6.4.bn2.bias              0.6.4.bn2.bias
-                           0.6.5.conv1.weight
0.6.5.bn1.weight            0.6.5.bn1.weight
0.6.5.bn1.bias              0.6.5.bn1.bias
-                           0.6.5.conv2.weight
0.6.5.bn2.weight            0.6.5.bn2.weight
0.6.5.bn2.bias              0.6.5.bn2.bias

-                           0.7.0.conv1.weight
0.7.0.bn1.weight            0.7.0.bn1.weight
0.7.0.bn1.bias              0.7.0.bn1.bias
-                           0.7.0.conv2.weight
0.7.0.bn2.weight            0.7.0.bn2.weight
0.7.0.bn2.bias              0.7.0.bn2.bias
-                           0.7.0.downsample.0.weight
0.7.0.downsample.1.weight   0.7.0.downsample.1.weight
0.7.0.downsample.1.bias     0.7.0.downsample.1.bias
-                           0.7.1.conv1.weight
0.7.1.bn1.weight            0.7.1.bn1.weight
0.7.1.bn1.bias              0.7.1.bn1.bias
-                           0.7.1.conv2.weight
0.7.1.bn2.weight            0.7.1.bn2.weight
0.7.1.bn2.bias              0.7.1.bn2.bias
-                           0.7.2.conv1.weight
0.7.2.bn1.weight            0.7.2.bn1.weight
0.7.2.bn1.bias              0.7.2.bn1.bias
-                           0.7.2.conv2.weight
0.7.2.bn2.weight            0.7.2.bn2.weight
0.7.2.bn2.bias              0.7.2.bn2.bias

1.2.weight                  1.2.weight
1.2.bias                    1.2.bias
1.4.weight                  1.4.weight
1.4.bias                    1.4.bias
1.6.weight                  1.6.weight
1.6.bias                    1.6.bias
1.8.weight                  1.8.weight
1.8.bias                    1.8.bias
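
Rather than comparing the two lists by eye, the difference can also be computed with a set operation. A sketch, assuming you start from a freshly created (still frozen) learner:

frozen = {n for n, p in learn.model.named_parameters() if p.requires_grad}
learn.unfreeze()
unfrozen = {n for n, p in learn.model.named_parameters() if p.requires_grad}

print(len(frozen), len(unfrozen))  # -> 80 116
# the 36 names unlocked by unfreeze(): 0.0.weight, every conv1/conv2
# weight, and the downsample.0 (shortcut 1x1 conv) weights
print(sorted(unfrozen - frozen))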



So 80 parameter tensors are updated under learn.freeze() and 116 under learn.unfreeze().
"So what?" you might ask, but in transfer learning, which parameters you retrain (update) appears to have a large effect on the training result. I suspect Jeremy Howard settled on this particular freezing scheme through years of experience.
If you are curious, you could experiment with this yourself, presumably hand-rolled in PyTorch (see the sketch below). For my part, I have no plans to dig any deeper for now.
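
For reference, a minimal hand-rolled sketch in plain PyTorch. Note that unlike fastai's default (train_bn=True), this locks the body's BatchNorm layers as well:

# lock the whole pretrained body (group 0) and train only the head (group 1)
for param in learn.model[0].parameters():
    param.requires_grad = False
for param in learn.model[1].parameters():
    param.requires_grad = True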

