遊具

たくさんのおもちゃ

Maternカーネルで遊ぶ

はじめに

久々の更新となりました。最近、研究でガウス過程回帰を用いる用があり、MLPシリーズのガウス過程と機械学習で勉強していた所、Maternカーネルなる面白い物を見つけました。Maternカーネルはとあるパラメータ(後述)を変える事で、指数カーネルになったりRBFカーネルになったりするようです。非常に面白いですよね。今回はこのMaternカーネルなる不思議なカーネルで遊んでいきたいと思います*1*2カーネルとはなんぞ??みたいな話はしないので、ある程度カーネルの話について知っていると読みやすいと思います。

Maternカーネル

Maternカーネルの定義

まず初めに式を示します。

 \displaystyle k_{\nu} ( {\bf x, x^{'} } ) = \frac{ 2^{1-\nu} }{ \Gamma (\nu) } ( \frac{ \sqrt{2\nu}r }{\theta} )^{\nu} K_{\nu} ( \frac{ \sqrt{2\nu}r }{\theta} )

ここで、 \displaystyle K_{\nu} は第2種の変形ベッセル関数、 \displaystyle \thetaはスケールパラメータ、 \displaystyle r = | {\bf x - x^{'} } | です。

色々なMaternカーネル

定義に現れる \displaystyle \nuが関数の滑らかさと関わりがあり、 \displaystyle \nu = \frac{1}{2} の時は指数カーネル \displaystyle \nu = \inftyの時はRBFカーネルと一致します。他にも、Matern3, Matern5なども存在しているようです。ガウス過程と機械学習では、Matern3とMatern5の図が現れており、ああ確かにパラメータ変えれば滑らかになるんだなあという気持ちになります。

他のパラメータは使っちゃダメなの?

というのが今回の遊ぶポイントです。ある時は指数カーネル、またある時はRBFカーネルなど、パラメータによって滑らかさが変化してくれるのは良いけど、どうしてそのパラメータじゃなきゃダメなのでしょうか*3。また、パラメータを規則的に変化させると、滑らかさも規則的に変化するのでしょうか?

ちなみに、ガウス過程はGPyやscikit-learnなど、有名どころのパッケージを入れる事で簡単に遊ぶことができます。例えば、scikit-learnでMaternカーネルを調べると、次のような説明が載っています。

The parameter nu controlling the smoothness of the learned function.

(省略)

Note that values of nu not in [0.5, 1.5, 2.5, inf] incur a considerably higher computational cost (appr. 10 times higher) since they require to evaluate the modified Bessel function. Furthermore, in contrast to l, nu is kept fixed to its initial value and not optimized.

scikit-learn.org


指数カーネル、Mattern3, Mattern5、RBFカーネル以外を指定すると、やたらと計算量食うから気をつけろよとの記述があります。でもそんなことは関係ありません。早速遊んでみましょう。

実験

仮想データセットの作成

まずはデータセットの作成です。次の式からデータセットを作成します。

 y = |sin2x| + \epsilon,  \epsilon~N(0, 0.1^2)

f:id:seibibibi:20200504135732p:plain

実験では、上の式から30個のサンプルを取得し、このデータに対してガウス過程回帰を掛け、期待値と1σ範囲をプロットします。また、滑らかさを調節するパラメータ \displaystyle \nuは0.01から1.5まで0.01刻みで変化させることにします。

結果

結果はこんな感じになりました。赤が予測、黄色が1σ範囲です。

f:id:seibibibi:20200504134847g:plain

最初は全然データに対応できていない上に結構変化が激しいけど、時間が経つとなんだか変化がスローモーションっぽくなりますね。規則的にパラメータを変化させても滑らかさの変化は規則的ではないような気がしますね。加えて、Mattern0 ( \displaystyle \nu=0)みたいなカーネルがあったとすれば、きっと使い物にならないんだろうなあという事も副次的に分かりました。

終わりに

今回はMaternカーネルで遊んでみました。正直記事として耐える内容なのかそうでないのかだいぶ迷いましたが、GWなので(?)とりあえず書きました。ところでMaternのeの上に「ちょん」*4がついているんですけど、どうやってキーボード入力するんですかね。いちいちコピペするのも面倒だし…まあいいや

プログラム

一応載せておきます

# seibisi

from sklearn.gaussian_process import kernels
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from matplotlib import animation

def make(x):
    return np.abs(np.sin(2*x)) + np.random.normal(0, 0.1, len(x))

# 当てはめ,予測
def fit_pred(x, nu, df):
    # 当てはめ
    kern = kernels.Matern(nu=nu)
    clf = GaussianProcessRegressor(kernel=kern, alpha=1e-10, n_restarts_optimizer=30, normalize_y=True)
    clf.fit(df["x"].values.reshape(-1, 1), df["y"])
    # 予測
    exp, std = clf.predict(x.reshape(-1, 1), return_std = True)
    return exp, std
    
def main():
    # 1. データセットの作成
    x = np.linspace(-2, 2, 30)
    df = pd.DataFrame(np.transpose([x, make(x)]), columns=["x","y"])
    df = df.sort_values("x").reset_index(drop=True)
    # 2. 当てはめと予測
    nus = np.linspace(0.01, 1.5, 150)
    x_test = np.linspace(-3, 3, 200)
    exp = [None for i in range(len(nus))]
    std = [None for i in range(len(nus))]
    for i in range(len(nus)):
        exp[i], std[i] = fit_pred(x_test, nus[i], df)
    # 3. アニメーションで出力
    k = 1
    fig = plt.figure()
    axes = fig.add_subplot(111)
    ims = []
    for i in range(len(exp)):
        img = [axes.scatter(df["x"], df["y"], marker="x", color="purple")]
        img += axes.plot(x_test, exp[i], color="red")
        img += axes.plot(x_test, exp[i]-k*std[i], color="orange")
        img += axes.plot(x_test, exp[i]+k*std[i], color="orange")
        #img += axes.fill_between(x_test, exp[i]-k*std[i], exp[i]+k*std[i], color="orange", alpha=0.1)
        ims.append(img)
    ani = animation.ArtistAnimation(fig, ims, interval = 80)
    ani.save("ani.gif", writer="pillow")
    
if __name__ == "__main__":
    main()

*1:実は既に遊んでいてTwitterにも載せていたけど、敢えて記事にするほどの内容かどうかで迷っていてなんやかんやで時間が過ぎてしまった

*2:とりあえず書くことにした

*3:指数カーネルとRBFカーネルに一致する話については、ダメというよりも一致するから便利よ位の話だと思うので別に良いけど、Matern3とかMater5とか…

*4:ドイツ語のウムラウト的な…?