ベイズ統計で重回帰分析

Table of Content

ベイズ統計における重回帰分析

頻度主義とベイズ統計のそれぞれで重回帰分析

  • これからベイス統計をやってみたい人
  • StanとRを使って何か始めたい人
    向けです。

参考にした本

頻度主義による重回帰分析

前回と同じように、Rのlmを用いて回帰分析を行う。ここではmtcarsdispmpgcylでフィッティングしてみる。

data <- mtcars
library(ggplot2)
library(GGally)

data$cyl <- as.factor(data$cyl)
p <- ggplot(data, aes(x=mpg, y=disp, group=cyl)) + 
     geom_point(aes(color=cyl), shape=1, size=5)

res_lm <- lm(disp ~ mpg + cyl, data)
summary(res_lm)

Call:
lm(formula = disp ~ mpg + cyl, data = data)

Residuals:
    Min      1Q  Median      3Q     Max 
-76.599 -25.487  -8.882  18.925  86.296 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  291.955     73.966   3.947 0.000484 ***
mpg           -7.006      2.722  -2.574 0.015638 *  
cyl6          29.688     29.610   1.003 0.324629    
cyl8         166.943     36.786   4.538 9.79e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 47.25 on 28 degrees of freedom
Multiple R-squared:  0.8687,    Adjusted R-squared:  0.8547 
F-statistic: 61.77 on 3 and 28 DF,  p-value: 1.83e-12

ベイズ統計による重回帰分析

Stanによるモデルファイルを下記のように書く。今回もパラメータの事前分布は、無情報事前分布として指定しない。

data {
    int N;
    real<lower=0> Mpg[N];
    int<lower=4, upper=8> Cyl[N];
    real Y[N];
}

parameters {
    real b1;
    real b2;
    real b3;
    real<lower=0> sigma;
}

transformed parameters {
    real mu[N];
    for (n in 1:N)
        mu[n] = b1 + b2*Mpg[n] + b3*Cyl[n];
}

model {
    for (n in 1:N)
        Y[n] ~ normal(mu[n], sigma);
}

generated quantities {
    real y_pred[N];
    for (n in 1:N) {
        y_pred[n] = normal_rng(mu[n], sigma);
    }
}

次に実行する。

data <- list(N=nrow(d), Mpg=d$mpg, Cyl=d$cyl, Y=d$disp)
stanmodel <- stan_model(file='model/stan_model_gen.stan')
fit <- sampling(
+    stanmodel,
+    data=data,
+    init=function() {
+        list(a=runif(1,-10,10), b=runif(1,0,10), sigma=10)
+    },
+    seed=1234,
+    chains = 4, iter = 1000, warmup=200, thin=2
+ )

summary(fit)で各パラメータの平均・標準偏差・quantileが確認できる。

summary(fit)$summary[1:4,]
           mean   se_mean         sd       2.5%        25%       50%        75%
b1    64.913028 4.3341966 123.118247 -176.36942 -16.905147 65.659685 148.789162
b2    -5.890881 0.1066988   3.088508  -11.85156  -7.918141 -5.899063  -3.902817
b3    45.909303 0.3692127  10.593797   25.18512  38.811023 46.056718  52.962120
sigma 54.579099 0.2463119   7.728447   42.09871  49.306562 53.784017  58.859763
            97.5%    n_eff     Rhat
b1    303.4839442 806.9147 1.000391
b2      0.2014996 837.8739 1.000653
b3     67.0406470 823.2849 1.000503
sigma  72.0140416 984.4953 1.001003

グラフにプロットする。

ms <- rstan::extract(fit)

data.frame.quantile.mcmc <- function(x1, x2, y_mcmc, probs=c(2.5, 25, 50, 75, 97.5)/100) {
    qua <- apply(y_mcmc, 2, quantile, probs=probs)
    d <- data.frame(X1=x1, X2=x2, t(qua))
    colnames(d) <- c('mpg', 'cyl', paste0('p', probs*100))
    return(d)
}

ggplot.5quantile <- function(data) {
    p <- ggplot(data=data)
    p <- p + theme_bw(base_size=18)
    p <- p + geom_ribbon(aes(x=mpg, ymin=p2.5, ymax=p97.5, fill=cyl), alpha=1/6)
    p <- p + geom_ribbon(aes(x=mpg, ymin=p25, ymax=p75, fill=cyl), alpha=2/6)
    p <- p + geom_line(aes(x=mpg, y=p50, color=cyl), size=1)
    return(p)
}

d_est <- data.frame.quantile.mcmc(x1=d$mpg, x2=d$cyl, y_mcmc=ms$y_pred)
d_est$cyl <- as.factor(d_est$cyl)
p <- ggplot.5quantile(data=d_est)
d$cyl <- as.factor(d$cyl)
p <- p + geom_point(data=d, aes(x=mpg, y=disp, color=cyl), shape=1, size=5)
p

単回帰分析の場合は、こちらをどうぞ

コメントを残す

Scroll to top