Using cmdstanr in SimDesign

Programming
Statistics
Author

Mark Lai

Published

July 14, 2021

Modified

November 3, 2023

library(SimDesign)
library(cmdstanr)

[Update: Use parallel computing with two cores.]

Adapted from https://cran.r-project.org/web/packages/SimDesign/vignettes/SimDesign-intro.html

See https://mc-stan.org/cmdstanr/articles/cmdstanr.html for using cmdstanr

# Fully crossed simulation design: four sample sizes by two data-generating
# distributions, i.e. 8 conditions in total.
Design <- createDesign(
    sample_size = c(30, 60, 120, 240),
    distribution = c("norm", "chi")
)
Design
# A tibble: 8 × 2
  sample_size distribution
        <dbl> <chr>       
1          30 norm        
2          60 norm        
3         120 norm        
4         240 norm        
5          30 chi         
6          60 chi         
7         120 chi         
8         240 chi         
# Generate one simulated sample for a single design condition.
#
# Args:
#   condition: one row of the design; must provide `sample_size` and
#     `distribution` (either 'norm' or 'chi').
#   fixed_objects: not used here; kept so the signature matches what
#     runSimulation() passes to its generate function.
#
# Returns:
#   A numeric vector of length `sample_size`, drawn from a normal
#   distribution with mean 3 ('norm') or a chi-square distribution with
#   3 degrees of freedom ('chi'). Both have a population mean of 3.
#
# Raises:
#   An error naming the offending value when `distribution` is not one of
#   the supported labels.
Generate <- function(condition, fixed_objects = NULL) {
    N <- condition$sample_size
    # as.character() keeps switch() matching by name even if the column
    # happens to be stored as a factor.
    dist <- as.character(condition$distribution)
    switch(dist,
           norm = rnorm(N, mean = 3),
           chi = rchisq(N, df = 3),
           # Without this default an unknown label would surface only as
           # "object 'dat' not found", which hides the real cause.
           stop("Unknown distribution: '", dist, "'", call. = FALSE))
}

Define Bayes estimator of the mean with STAN

# Stan model, kept as a string so it can be written to a file and compiled.
# It declares N observations x, a mean parameter mu and a positive scale
# parameter sigma. The model block adds a normal(0, 10) log-density for mu
# and a normal(mu, sigma) log-density for x; no explicit prior statement is
# given for sigma beyond its lower bound of 0.
bmean_stan <- "
    data {
        int<lower=0> N;
        array[N] real x;
    }
    parameters {
        real mu;
        real<lower=0> sigma;
    }
    model {
        target += normal_lpdf(mu | 0, 10);  // weakly informative prior
        target += normal_lpdf(x | mu, sigma);
    }
"
# Save the program to a .stan file and build a model object from it.
# This is done once here; `mod` is later handed to the simulation through
# `fixed_objects` rather than being rebuilt inside Analyse().
stan_path <- write_stan_file(bmean_stan)
mod <- cmdstan_model(stan_path)
# Compute five estimates of the population mean from one simulated sample.
#
# Args:
#   condition: the current design row (not used in the computation).
#   dat: numeric vector of simulated observations.
#   fixed_objects: list whose `mod` element is the compiled Stan model.
#
# Returns:
#   A named numeric vector with the untrimmed mean, the 10% and 20% trimmed
#   means, the median, and the posterior mean of `mu` from the Stan fit.
Analyse <- function(condition, dat, fixed_objects = NULL) {
    stan_model <- fixed_objects$mod
    stan_data <- list(x = dat, N = length(dat))

    # Single chain with progress output turned off; quiet() wraps the call
    # to keep the sampler from printing during the simulation.
    fit <- quiet(stan_model$sample(stan_data,
                                   refresh = 0, chains = 1,
                                   show_messages = FALSE))
    mu_summary <- fit$summary("mu", mean)

    c(mean_no_trim = mean(dat),
      mean_trim.1 = mean(dat, trim = .1),
      mean_trim.2 = mean(dat, trim = .2),
      median = median(dat),
      bayes_mean = mu_summary$mean[1])
}
# Summarise the replications of one design condition.
#
# Args:
#   condition: the current design row (not used in the computation).
#   results: the collected Analyse() outputs for this condition.
#   fixed_objects: not used here.
#
# Returns:
#   A named numeric vector holding, for each estimator, its bias and RMSE
#   against the true mean of 3, and RE() applied to the RMSE values.
Summarise <- function(condition, results, fixed_objects = NULL) {
    true_mean <- 3
    rmse_by_estimator <- RMSE(results, parameter = true_mean)
    c(bias = bias(results, parameter = true_mean),
      RMSE = rmse_by_estimator,
      RE = RE(rmse_by_estimator))
}
# Run 50 replications of every design row on at most two worker processes.
# parallel::detectCores() is documented to return NA when the number of
# cores cannot be determined; na.rm = TRUE makes min() fall back to 2 in
# that case instead of passing NA on as `ncores`.
# The compiled model travels to Analyse() via `fixed_objects`, and
# `packages` names cmdstanr so it is available during the parallel run.
res <- runSimulation(Design, replications = 50, generate = Generate,
                     analyse = Analyse, summarise = Summarise,
                     parallel = TRUE,
                     ncores = min(2, parallel::detectCores(), na.rm = TRUE),
                     fixed_objects = list(mod = mod),
                     packages = "cmdstanr")

Number of parallel clusters in use: 2


Design row: 1/8;   RAM used: 68.2 Mb;   Total elapsed time: 0.00s 
 Conditions: sample_size=30, distribution=norm


Design row: 2/8;   RAM used: 69 Mb;   Total elapsed time: 5.22s 
 Conditions: sample_size=60, distribution=norm


Design row: 3/8;   RAM used: 69 Mb;   Total elapsed time: 10.26s 
 Conditions: sample_size=120, distribution=norm


Design row: 4/8;   RAM used: 69 Mb;   Total elapsed time: 15.12s 
 Conditions: sample_size=240, distribution=norm


Design row: 5/8;   RAM used: 69 Mb;   Total elapsed time: 20.10s 
 Conditions: sample_size=30, distribution=chi


Design row: 6/8;   RAM used: 69 Mb;   Total elapsed time: 25.06s 
 Conditions: sample_size=60, distribution=chi


Design row: 7/8;   RAM used: 69 Mb;   Total elapsed time: 29.79s 
 Conditions: sample_size=120, distribution=chi


Design row: 8/8;   RAM used: 69 Mb;   Total elapsed time: 34.88s 
 Conditions: sample_size=240, distribution=chi

Simulation complete. Total execution time: 39.69s
# Render the simulation results (one row per design condition) as a table.
knitr::kable(res)
sample_size distribution bias.mean_no_trim bias.mean_trim.1 bias.mean_trim.2 bias.median bias.bayes_mean RMSE.mean_no_trim RMSE.mean_trim.1 RMSE.mean_trim.2 RMSE.median RMSE.bayes_mean RE.mean_no_trim RE.mean_trim.1 RE.mean_trim.2 RE.median RE.bayes_mean REPLICATIONS SIM_TIME RAM_USED SEED COMPLETED
30 norm 0.0137592 0.0072189 0.0051610 -0.0088380 0.0134503 0.1831414 0.2016381 0.2068831 0.2458775 0.1837847 1 1.212194 1.276078 1.802455 1.0070379 50 5.223 69 Mb 77254733 Thu Mar 21 10:32:12 2024
60 norm -0.0161554 -0.0167838 -0.0165226 -0.0155586 -0.0181256 0.1171312 0.1192122 0.1279795 0.1617023 0.1176554 1 1.035847 1.193811 1.905843 1.0089697 50 5.041 69 Mb 745175193 Thu Mar 21 10:32:17 2024
120 norm -0.0135750 -0.0129037 -0.0129281 -0.0094663 -0.0135980 0.0763685 0.0843901 0.0888738 0.1060062 0.0766066 1 1.221109 1.354312 1.926789 1.0062448 50 4.852 69 Mb 796799714 Thu Mar 21 10:32:22 2024
240 norm 0.0112872 0.0172670 0.0194530 0.0195026 0.0113646 0.0661500 0.0699354 0.0717609 0.0786083 0.0664199 1 1.117724 1.176837 1.412139 1.0081765 50 4.982 69 Mb 1458160184 Thu Mar 21 10:32:27 2024
30 chi 0.0141920 -0.3026217 -0.4385604 -0.5725934 0.0047972 0.4931700 0.5686208 0.6421671 0.7338760 0.4913024 1 1.329389 1.695520 2.214380 0.9924404 50 4.962 69 Mb 153258056 Thu Mar 21 10:32:32 2024
60 chi 0.0219076 -0.3127238 -0.4623196 -0.6273301 0.0210254 0.2738739 0.4046596 0.5340195 0.6891814 0.2767046 1 2.183125 3.802007 6.332363 1.0207785 50 4.729 69 Mb 838557643 Thu Mar 21 10:32:37 2024
120 chi -0.0170950 -0.3392576 -0.4702292 -0.5832027 -0.0186802 0.2179343 0.3971243 0.5138743 0.6278862 0.2178307 1 3.320487 5.559843 8.300623 0.9990498 50 5.093 69 Mb 1973331824 Thu Mar 21 10:32:42 2024
240 chi 0.0242644 -0.3314124 -0.4775029 -0.6166151 0.0238784 0.1424573 0.3627346 0.5014147 0.6403622 0.1439531 1 6.483486 12.388673 20.206075 1.0211110 50 4.808 69 Mb 1726109243 Thu Mar 21 10:32:47 2024