Table of Content
ベイズ統計における重回帰分析
頻度主義とベイズ統計のそれぞれで重回帰分析
- これからベイス統計をやってみたい人
- StanとRを使って何か始めたい人
向けです。
参考にした本
-
参考リンク
ベイズ統計で単回帰分析
頻度主義による重回帰分析
前回と同じように、Rのlm
を用いて回帰分析を行う。ここではmtcars
のdisp
をmpg
とcyl
でフィッティングしてみる。
data <- mtcars
library(ggplot2)
library(GGally)
data$cyl <- as.factor(data$cyl)
p <- ggplot(data, aes(x=mpg, y=disp, group=cyl)) +
geom_point(aes(color=cyl), shape=1, size=5)
res_lm <- lm(disp ~ mpg + cyl, data)
summary(res_lm)
Call:
lm(formula = disp ~ mpg + cyl, data = data)
Residuals:
Min 1Q Median 3Q Max
-76.599 -25.487 -8.882 18.925 86.296
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 291.955 73.966 3.947 0.000484 ***
mpg -7.006 2.722 -2.574 0.015638 *
cyl6 29.688 29.610 1.003 0.324629
cyl8 166.943 36.786 4.538 9.79e-05 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 47.25 on 28 degrees of freedom
Multiple R-squared: 0.8687, Adjusted R-squared: 0.8547
F-statistic: 61.77 on 3 and 28 DF, p-value: 1.83e-12
ベイズ統計による重回帰分析
Stanによるモデルファイルを下記のように書く。今回もパラメータの事前分布は、無情報事前分布として指定しない。
data {
int N;
real<lower=0> Mpg[N];
int<lower=4, upper=8> Cyl[N];
real Y[N];
}
parameters {
real b1;
real b2;
real b3;
real<lower=0> sigma;
}
transformed parameters {
real mu[N];
for (n in 1:N)
mu[n] = b1 + b2*Mpg[n] + b3*Cyl[n];
}
model {
for (n in 1:N)
Y[n] ~ normal(mu[n], sigma);
}
generated quantities {
real y_pred[N];
for (n in 1:N) {
y_pred[n] = normal_rng(mu[n], sigma);
}
}
次に実行する。
data <- list(N=nrow(d), Mpg=d$mpg, Cyl=d$cyl, Y=d$disp)
stanmodel <- stan_model(file='model/stan_model_gen.stan')
fit <- sampling(
+ stanmodel,
+ data=data,
+ init=function() {
+ list(a=runif(1,-10,10), b=runif(1,0,10), sigma=10)
+ },
+ seed=1234,
+ chains = 4, iter = 1000, warmup=200, thin=2
+ )
summary(fit)で各パラメータの平均・標準偏差・quantileが確認できる。
summary(fit)$summary[1:4,]
mean se_mean sd 2.5% 25% 50% 75%
b1 64.913028 4.3341966 123.118247 -176.36942 -16.905147 65.659685 148.789162
b2 -5.890881 0.1066988 3.088508 -11.85156 -7.918141 -5.899063 -3.902817
b3 45.909303 0.3692127 10.593797 25.18512 38.811023 46.056718 52.962120
sigma 54.579099 0.2463119 7.728447 42.09871 49.306562 53.784017 58.859763
97.5% n_eff Rhat
b1 303.4839442 806.9147 1.000391
b2 0.2014996 837.8739 1.000653
b3 67.0406470 823.2849 1.000503
sigma 72.0140416 984.4953 1.001003
グラフにプロットする。
ms <- rstan::extract(fit)
data.frame.quantile.mcmc <- function(x1, x2, y_mcmc, probs=c(2.5, 25, 50, 75, 97.5)/100) {
qua <- apply(y_mcmc, 2, quantile, probs=probs)
d <- data.frame(X1=x1, X2=x2, t(qua))
colnames(d) <- c('mpg', 'cyl', paste0('p', probs*100))
return(d)
}
ggplot.5quantile <- function(data) {
p <- ggplot(data=data)
p <- p + theme_bw(base_size=18)
p <- p + geom_ribbon(aes(x=mpg, ymin=p2.5, ymax=p97.5, fill=cyl), alpha=1/6)
p <- p + geom_ribbon(aes(x=mpg, ymin=p25, ymax=p75, fill=cyl), alpha=2/6)
p <- p + geom_line(aes(x=mpg, y=p50, color=cyl), size=1)
return(p)
}
d_est <- data.frame.quantile.mcmc(x1=d$mpg, x2=d$cyl, y_mcmc=ms$y_pred)
d_est$cyl <- as.factor(d_est$cyl)
p <- ggplot.5quantile(data=d_est)
d$cyl <- as.factor(d$cyl)
p <- p + geom_point(data=d, aes(x=mpg, y=disp, color=cyl), shape=1, size=5)
p
単回帰分析の場合は、こちらをどうぞ