回帰計算のためのデータ事前処理

Intro

確率的勾配降下法のサンプルをネットで調べていたところ、検索結果のトップ群にQiitaの記事になかなか収束しないという話があったが、もろもろ調べて考察した結果、事前処理としてデータの正規化をすれば収束性が改善するらしい。ここでは事前処理のありなしで収束性がどう変わるか比較する。

Wikipediaの確率的勾配降下法のページの平均・分散の調整セクションに、”訓練データの正規化をして、平均0分散1になるように調整する方が良い”とある。これは、あまりにも大きさの異なるデータがあるとコスト関数の形状が潰れてしまい収束性が悪くなってしまうため、整形する目的で訓練データを定数で加減乗除して収束性の改善を図る。この定数に平均や分散を使う。

問題設定・正規化前の学習則

Qiitaの記事と問題設定を揃える。訓練データの数をmとして訓練データの集合を(x^{(1)}, y^{(1)}), (x^{(2)}, y^{(2)}), \cdots, (x^{(m)}, y^{(m)})と置く。線形回帰モデルはy=h_w(x)=w_0 + w_1 xとし、パラメータw = [w_0 , w_1]を学習する。コスト関数J(w)は正則化項を考慮せずに以下で表す。

\displaystyle J_i(w) = \frac{1}{2} \left\{ h_w(x^{(i)}) - y^{(i)} \right\}^2
\displaystyle J(w) = \frac{1}{2} \sum_{i=1}^m J_i(w)

もう少し詳しく書くと

\displaystyle J(w) = \frac{1}{2} \sum_{i=1}^m \left\{ w_0 + w_1 x^{(i)} - y^{(i)}\right\}^2

このJ_i(w)w_0, w_1で偏微分し勾配を求めると、

\displaystyle \frac{\partial J_i(w)}{\partial w_0} = \left(w_0 + w_1 x^{(i)} - y^{(i)}\right)
\displaystyle \frac{\partial J_i(w)}{\partial w_1} = \left(w_0 + w_1 x^{(i)} - y^{(i)}\right)x^{(i)}

学習率を\etaとおいて下2式のように同時に更新する。

\displaystyle w_0 := w_0 - \eta \frac{\partial J_i(w)}{\partial w_0}
\displaystyle w_1 := w_1 - \eta \frac{\partial J_i(w)}{\partial w_1}

正規化後の訓練データによる学習

訓練データ集合(x^{(1)}, y^{(1)}), (x^{(2)}, y^{(2)}), \cdots, (x^{(m)}, y^{(m)})から計算される平均\mu_x, \mu_y、分散\sigma_x^2, \sigma_y^2を使ってX^{(i)}, Y^{(i)}に線形変換して正規化する。

\displaystyle X^{(i)} = \frac{x^{(i)}-\mu_x}{\sigma_x}
\displaystyle Y^{(i)} = \frac{y^{(i)}-\mu_y}{\sigma_y}

このデータ集合を訓練データとして、別の線形回帰モデルY = H_W(X) = W_0 + W_1 XのパラメータW=[W_0, W_1]^Tを推定する。またコスト関数J_i'(W), J'(W)は以下で表す。

\displaystyle J'_i(W) = \frac{1}{2} \left\{ H_W(X^{(i)}) - Y^{(i)}\right\}^2
\displaystyle J'(W) = \frac{1}{2} \sum_{i=1}^m J'_i(W)

すなわち

\displaystyle J'(W) = \frac{1}{2} \sum_{i=1}^m \left\{ W_0 + W_1 X^{(i)} - Y^{(i)}\right\}^2

J'_i(W)W_0, W_1での偏微分はJ(w)と同様に

\displaystyle \frac{\partial J_i(W)}{\partial W_0} = \left(W_0 + W_1 X^{(i)} - Y^{(i)} \right)
\displaystyle \frac{\partial J_i(W)}{\partial W_1} = \left(W_0 + W_1 X^{(i)} - Y^{(i)} \right) X^{(i)}

学習率を\eta'とおいて下2式のように同時に更新する。

\displaystyle W_0 = W_0 - \eta' \frac{\partial J'_i(W)}{\partial W_0}
\displaystyle W_1 = W_1 - \eta' \frac{\partial J'_i(W)}{\partial W_1}

正規化後のモデルパラメータから正規化前への変換

正規化前後のパラメータw, W、コスト関数J(x), J'(X)は異なるデータ集合、異なるパラメータから求めているので単純にそれぞれを比較できない。比較のためにはWからwへ、J'(X)からJ(x)へ再変換する。Y = H_W(X) = W_0 + W_1 X\displaystyle X = \frac{x-\mu_x}{\sigma_x}\displaystyle Y = \frac{y-\mu_y}{\sigma_y}を代入して整理すると

\displaystyle \frac{y-\mu_y}{\sigma_y}= W_0 + W_1 \frac{x-\mu_x}{\sigma_x}
\displaystyle \leftrightarrow y =\left \{ \sigma_y W_0 - \frac{\sigma_y}{\sigma_x} \mu_x W_1 + \mu_y \right \} + \frac{\sigma_y}{\sigma_x} W_1 x

もとの線形回帰モデルy=h_w(x)=w_0 + w_1 xと比較して

\displaystyle w_0 = \sigma_y W_0 - \frac{\sigma_y}{\sigma_x} \mu_x W_1 + \mu_y
\displaystyle w_1 = \frac{\sigma_y}{\sigma_x} W_1

が得られる。上2式をJ_i(w)に代入して整理すると

\displaystyle J_i(w) = \sigma_y^2 \frac{1}{2} \left\{ W_0 + W_1 X^{(i)} - Y^{(i)} \right\}^2
\leftrightarrow \displaystyle J_i(w) = \sigma_y^2 J'_i(W)

すなわち

\displaystyle J(w) = \sigma_y^2 J'(W)

が求まる。

訓練データの生成と正規化

訓練データは以下の式で生成する。

\left( x^{(i)}, y^{(i)} \right) = \left( \mathcal{N}(\tilde{\mu}_x, \tilde{\sigma}_x^2), \mathcal{N}(\tilde{w}_0+\tilde{w}_1x^{(i)}, \tilde{\sigma}_y^2) \right)

\displaystyle \mathcal{N}(\mu, \sigma^2)は平均\mu・分散\sigma^2の正規分布を表す。

\displaystyle f_{\mu, \sigma^2}(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp \left( -\frac{(x-\mu)^2}{2\sigma^2} \right)

ただし上述の正規化のための線形変換には\tilde{\mu}_x, \tilde{\sigma}_x, \tilde{\sigma}_yは使用せず、訓練データから計算する。

m=1000
\tilde{\mu}_x=10
\tilde{\sigma}_x^2=4
\tilde{\sigma}_y^2=2
\tilde{w}_0=0
\tilde{w}_1=1

としてQiitaの訓練データ集合を似せた結果がこちら。左が生成したデータ、右が生成したデータに線形変換をして正規化したデータ。正規化後はゼロ付近に分散されている。
training_data

コスト関数は以下のように変わる。左が正規化前、右が正規化後のコスト関数。このコスト関数が最小になるように、凹地の底へ行くようにw_0, w_1, W_0, W_1が更新されていく。正規化前は凹地の底が細長くなっているのに対して、正規化後はキレイな同心円を描いてコスト関数ができている。正規化前のパラメータ更新の挙動は、細長い凹地には入るがそこからは勾配は非常に緩いため一番底へはなかなか落ちていかない。一方で正規化後は、どの方向からも凹地の一番底へ真っ直ぐに落ちていく。
cost_functions_3d

学習結果

初期パラメータをw_0=-1.5, w_1=-1.5、学習率を一定に\eta=0.006, \eta'=0.006として上述の学習則に沿って10,000回計算した。下図は計算したパラメータの結果をコスト関数を投影したコンター図上にプロットした結果を示している。左が正規化前、左が正規化後のプロットで、左図には正規化後から再変換したパラメータ更新の推移もプロットした。出発点Initial wを同じとして真値True wにいずれも近づいているが、挙動は異なることがわかる。
cost_functions_2d下図は横軸を反復回数の対数をとって、縦軸を正規化前の履歴と、正規化後から変換した履歴。下図中の左上がコスト関数の履歴、左下がw_0の学習履歴、右下がw_1の学習履歴。正規化前のコスト関数は10回程度で早く減少したものの、変動が大きくなっている。w_0は上のコスト関数でいう横軸にあたり、勾配が緩やかなためなかなか収束していないが、動きにくいという意味なので変動は小さく安定している。w_1は傾きが大きい=動きやすいため早く収束しているがその後も変動は大きいままになっている。この変動は学習率を下げることである程度小さくはなるが収束速度は遅くなる。

一方で、正規化後のコスト関数は収束が遅くなっているが、訓練データによる学習が一巡する1,000回あたりで変動は小さく安定して収束している。正規化後のw_1は安定しているのに対してw_0の変動が大きいことは気になる点。長くなっちゃったのでこれに関しては次回。

両者ともコスト関数の最小値は同じ値を示しており、パラメータの履歴と同じく、真値付近でウロウロしているのがここでも読み取れる。また、コスト関数はw_1つまり傾きの影響が大きいことがグラフから読み取れ、数式や直感とも合っている。
convergence_historyおまけ、正規化後の履歴。同様に左上がコスト関数の履歴、左下がW_0の推定履歴、右下がW_1の推定履歴。2つのパラメータW_0, W_1とも素直に真値へ収束している。
convergence_history_normalized_domain