Intro
確率的勾配降下法のサンプルをネットで調べていたところ、検索結果のトップ群にQiitaの記事になかなか収束しないという話があったが、もろもろ調べて考察した結果、事前処理としてデータの正規化をすれば収束性が改善するらしい。ここでは事前処理のありなしで収束性がどう変わるか比較する。
Wikipediaの確率的勾配降下法のページの平均・分散の調整セクションに、”訓練データの正規化をして、平均0分散1になるように調整する方が良い”とある。これは、あまりにも大きさの異なるデータがあるとコスト関数の形状が潰れてしまい収束性が悪くなってしまうため、整形する目的で訓練データを定数で加減乗除して収束性の改善を図る。この定数に平均や分散を使う。
問題設定・正規化前の学習則
Qiitaの記事と問題設定を揃える。訓練データの数をとして訓練データの集合をと置く。線形回帰モデルはとし、パラメータを学習する。コスト関数は正則化項を考慮せずに以下で表す。
もう少し詳しく書くと
このをで偏微分し勾配を求めると、
学習率をとおいて下2式のように同時に更新する。
正規化後の訓練データによる学習
訓練データ集合から計算される平均、分散を使ってに線形変換して正規化する。
このデータ集合を訓練データとして、別の線形回帰モデルのパラメータを推定する。またコスト関数は以下で表す。
すなわち
のでの偏微分はと同様に
学習率をとおいて下2式のように同時に更新する。
正規化後のモデルパラメータから正規化前への変換
正規化前後のパラメータ、コスト関数は異なるデータ集合、異なるパラメータから求めているので単純にそれぞれを比較できない。比較のためにはからへ、からへ再変換する。にとを代入して整理すると
もとの線形回帰モデルと比較して
が得られる。上2式をに代入して整理すると
すなわち
が求まる。
訓練データの生成と正規化
訓練データは以下の式で生成する。
は平均・分散の正規分布を表す。
ただし上述の正規化のための線形変換にはは使用せず、訓練データから計算する。
としてQiitaの訓練データ集合を似せた結果がこちら。左が生成したデータ、右が生成したデータに線形変換をして正規化したデータ。正規化後はゼロ付近に分散されている。
コスト関数は以下のように変わる。左が正規化前、右が正規化後のコスト関数。このコスト関数が最小になるように、凹地の底へ行くようにが更新されていく。正規化前は凹地の底が細長くなっているのに対して、正規化後はキレイな同心円を描いてコスト関数ができている。正規化前のパラメータ更新の挙動は、細長い凹地には入るがそこからは勾配は非常に緩いため一番底へはなかなか落ちていかない。一方で正規化後は、どの方向からも凹地の一番底へ真っ直ぐに落ちていく。
学習結果
初期パラメータを、学習率を一定にとして上述の学習則に沿って10,000回計算した。下図は計算したパラメータの結果をコスト関数を投影したコンター図上にプロットした結果を示している。左が正規化前、左が正規化後のプロットで、左図には正規化後から再変換したパラメータ更新の推移もプロットした。出発点Initial wを同じとして真値True wにいずれも近づいているが、挙動は異なることがわかる。
下図は横軸を反復回数の対数をとって、縦軸を正規化前の履歴と、正規化後から変換した履歴。下図中の左上がコスト関数の履歴、左下がの学習履歴、右下がの学習履歴。正規化前のコスト関数は10回程度で早く減少したものの、変動が大きくなっている。は上のコスト関数でいう横軸にあたり、勾配が緩やかなためなかなか収束していないが、動きにくいという意味なので変動は小さく安定している。は傾きが大きい=動きやすいため早く収束しているがその後も変動は大きいままになっている。この変動は学習率を下げることである程度小さくはなるが収束速度は遅くなる。
一方で、正規化後のコスト関数は収束が遅くなっているが、訓練データによる学習が一巡する1,000回あたりで変動は小さく安定して収束している。正規化後のは安定しているのに対しての変動が大きいことは気になる点。長くなっちゃったのでこれに関しては次回。
両者ともコスト関数の最小値は同じ値を示しており、パラメータの履歴と同じく、真値付近でウロウロしているのがここでも読み取れる。また、コスト関数はつまり傾きの影響が大きいことがグラフから読み取れ、数式や直感とも合っている。
おまけ、正規化後の履歴。同様に左上がコスト関数の履歴、左下がの推定履歴、右下がの推定履歴。2つのパラメータとも素直に真値へ収束している。