遊具

たくさんのおもちゃ

WAICで遊ぶ

1.はじめに

WAICはモデル選択のための評価基準の1つです。WAICは汎化損失を推定する量として非常に有用性の高い評価基準となっています。この辺りの理論や細かい話については、言わずと知れたベイズ統計の名著であるベイズ統計の理論と方法(渡辺ベイズ)や種々の記事に譲ること*1にして、本記事では実際に手を動かしてWAICで遊んでみたいと思います。ある程度渡辺ベイズを読まれていると、何をして遊んでいるのかが分かるかもしれません。

※第3章までは泣く泣く読んでいきましたが、第4章になってから途端に読めなくなってしまったので、とりあえず一旦整理のために手を動かしてみました…間違いや解釈違いは多々あると思います…

2.汎化損失とWAIC

さて、実際に遊んでいく前に、少しだけ汎化損失とWAIC周りの話を簡単に述べていきます。初めに、汎化損失 \displaystyle G_nは以下の式で定義されます。

 \displaystyle G_n = - \int q(x)logp^*(x)dx

汎化損失は真の分布 \displaystyle q(x)に対して、予測分布 \displaystyle p^*(x)がどの程度推測できているかを量的に示しています。

しかし、実際には真の分布 \displaystyle q(x)は未知であるため、正確な汎化損失を計算することはできません。従って、何らかの形で「推定」が必要になります。このような状況で、WAICは汎化損失の推定量として大きな役割を果たします。WAICは次の式で定義されます。

 \displaystyle W_n = T_n + \frac{\beta V_n}{n}

ここで、 \displaystyle T_nは経験損失、 \displaystyle V_n汎関数分散、 \displaystyle \betaは逆温度、 \displaystyle nはデータの数です。経験損失と汎関数分散については以下の式で計算されます。

 \displaystyle T_n = - \frac{1}{n} \sum_{i=1}^{n} log E_{\omega} [ p(X_i | \omega) ]

 \displaystyle V_n = \sum_{i=1}^{n} \{ E_{\omega} [ ( logp(X_i | \omega) )^2 ] - E_{\omega} [ logp(X_i | \omega) ]^2  \}

ここで、 \displaystyle \omegaは推定されたパラメータ、 \displaystyle E_{\omega} [] は推定された事後分布で期待値を取る操作になります。経験損失も汎関数分散も、パラメータさえ推定してしまえば計算できる量なので、WAICは手元にあるデータから完全に計算することができます。

加えて、汎化損失とWAICについて、次に示す非常に興味深い関係が存在しています。

 \displaystyle E[G_n] = E[W_n] + o(\frac{1}{n})

つまり、データを取ってはWAICを計算して、またデータを取ってはWAICを計算して…を繰り返して期待値を取ると、汎化損失の期待値に一致するというものです。汎化損失は未知の分布 \displaystyle q(x)が入っていることで実際に計算することができない量でしたが、平均的に考えればこの量を見事に推定できているという関係です。非常に面白い関係です。

WAICを利用してモデル選択を行う際は、WAICが小さいモデルを選択すると良いわけですが、WAICは確率変数であることを頭に入れておくことが大切です。


3.シミュレーション

というわけで、早速WAICで遊んでみましょう。

未知の分布 \displaystyle q(x)の作成

今回のシミュレーションでは、未知の分布を混合正規分布としました。具体的には、混合数を3(混合比= \displaystyle (0.2, 0.5, 0.3))とし、 \displaystyle N(0, 1), N(2, 1), N(4, 1)を混ぜています。この \displaystyle q(x)に対して、①通常の正規分布②混合数2の混合正規分布③混合数3の混合正規分布の3パターンを確率モデルとして仮定したときに、それぞれのWAICがどのようにふるまうのかを見ていきます。真の分布は次の図で示す通りです。

f:id:seibibibi:20200315115431p:plain

実験結果1 WAICのHello World

サンプリングするデータ数を50にした時のヒストグラムをはじめに示します。左側の赤いグラフが \displaystyle q(x)、右側が実際のデータです。

f:id:seibibibi:20200315115541p:plain

このデータに基づいて、予測分布 \displaystyle p^*(x)を求めていきます。今回はシミュレーションなので、すべての分散を既知(分散=1ですべての正規分布に渡って等分散)し、すべての平均の事前分布を \displaystyle N(0, 1)としておきました。

このデータに基づいて、作成された予測分布を次に示します。左から、通常の正規分布、混合数2、混合数3の予測分布です。

f:id:seibibibi:20200315115714p:plain

汎化損失はそれぞれ2.686、2.000、1.984となり、WAICはそれぞれ2.420、2.110、2.069となりました。汎化損失とWAICはだいたい近い値を取り、モデルは混合数3の混合正規分布を選ぶと良さそうです*2

実験結果2 汎化損失とWAICの分布

汎化損失とWAICの期待値を実際に比較してみましょう。サンプリングするデータ数と事前分布を実験結果1と同じにしておき、混合数2の混合正規分布を確率モデルとして用いた*3時の汎化損失とWAICのヒストグラムを以下に示します。ここで、サンプリングからWAIC、汎化損失の計算は100回行いました。

f:id:seibibibi:20200315164357p:plain

それぞれの平均は汎化損失が2.004、WAICが2.044でした。ほとんど一致しています。すごい。

実験結果3 サンプルサイズを変更したときのWAICの分布

計算回数、確率モデル、事前分布は実験結果2と同じに揃え、データ数を50, 100, 150と変えた時のWAICの分布を示します。

f:id:seibibibi:20200315113145p:plain

データ数が増えるほど、ある1点に集中する傾向があります。面白いなあ。


4.まとめ

本記事ではWAICを使って遊んでみました。やはり数式を追うだけではなく実際に手を動かしてどうなるのかを確かめるのは楽しいですね。非常に勉強になりました。正則理論までしか読めていませんが、数学をちょこちょこ勉強して、一般理論以降の章を読めるように頑張っていきたいところです。


プログラム

WAICの実装自体はRで行い、次の記事を参考にさせていただきました。
statmodeling.hatenablog.com

パラメータのサンプリング自体はStanを利用しました。Stanのプログラムは以下になります。基本的にはStan User's Guideを参考にしましたが、その通りの実装してみると案の定上手く行かなかった*4(あるある)ので、分散を既知にした状態でプログラムを書きました。Stan User's Guideは以下になります。

https://mc-stan.org/docs/2_19/stan-users-guide-2_19.pdf

data{
  int<lower=1> K;
  int<lower=1> N;
  real y[N];
}
parameters{
  simplex[K] pi;
  ordered[K] mu;
}
model{
  vector[K] log_pi = log(pi);
  mu ~ normal(0, 1);
  for(n in 1:N){
    vector[K] lps = log_pi;
    for(k in 1:K)
      lps[k] += normal_lpdf(y[n] | mu[k], 1);
    target += log_sum_exp(lps);
  }
}
generated quantities{
  vector[N] log_likelihood;
  int index;
  real y_pred;
  for(n in 1:N){
    vector[K] lp;
    for(k in 1:K)
      lp[k] = log(pi[k]) + normal_lpdf(y[n] | mu[k], 1);
    log_likelihood[n] = log_sum_exp(lp);
  }
  index = categorical_rng(pi);
  y_pred = normal_rng(mu[categorical_rng(pi)], 1);
}

*1:譲るとは言ったものの、渡辺ベイズの中で展開される数式を追い切れていないので、完全に丸投げの形になってしまいますが…

*2:ちなみに、WAICの差がどの程度本質的なのかは僕はよくわかっていません

*3:混合数3を採用すると遅いので混合数2にしました。ただせっかちなだけです…

*4:具体的には、収束に関する警告がどっさり出てきます。ここの問題はまだ未解決なので色々調べてみたいと思います。