Predicting new observations with a previously fitted BART model
BART is a Bayesian sum-of-trees model.
For a numeric response y, we have y = f(x) + e, where e ~ N(0, sigma^2).
Here f is the sum of many tree models. The goal is flexible inference for the unknown function f.
In the spirit of ensemble models, each tree is constrained by a prior to be a weak learner, so that it contributes only a small amount to the overall fit.
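The model described above can be written compactly as follows. The symbols m (number of trees), T_j (the j-th tree structure), M_j (its terminal-node values) and g (the function that maps x to a terminal-node value) are the usual notation from the BART literature and do not appear on this page, so treat them as assumptions.

```latex
y_i = f(x_i) + \epsilon_i, \qquad \epsilon_i \sim N(0, \sigma^2), \qquad
f(x) = \sum_{j=1}^{m} g(x;\, T_j, M_j)
```

The prior keeps each g(x; T_j, M_j) small, which is what makes each tree a weak learner.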
Usage
## S3 method for class 'crisk2bart'
predict(object, newdata, newdata2, mc.cores=1, openmp=(mc.cores.openmp()>0), ...)
Arguments
object: Object returned from a previous BART fit with crisk2.bart or mc.crisk2.bart.
newdata: Matrix of covariates to predict the distribution of t1.
newdata2: Matrix of covariates to predict the distribution of t2.
mc.cores: Number of threads to utilize.
openmp: Logical value dictating whether OpenMP is used for parallel processing. This depends on whether OpenMP is available on your system, which is checked by default with mc.cores.openmp.
...: Other arguments which will be passed on to pwbart.
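When choosing a value for mc.cores, it can help to avoid asking for more threads than the machine reports. The following is a minimal sketch using only the parallel package that ships with R. The cap of 8 mirrors the example on this page and is an arbitrary choice, and the fallback to one core on non-Unix systems is an assumption based on the example's comment that parallel::mcparallel and mccollect do not exist on Windows.

```r
## Pick a thread count for mc.cores (illustrative only; 'cores' and
## 'available' are hypothetical names, not part of the BART API).
available <- parallel::detectCores()
if (is.na(available)) available <- 1L   # detectCores() may return NA

## Assumption: use a single core where forking is unavailable.
cores <- if (.Platform$OS.type == "unix") min(8L, available) else 1L
print(cores)
```

The resulting value could then be passed as mc.cores=cores in a call to predict.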
Details
BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior (f, sigma) | (x, y) in the numeric y case and just f in the binary y case.
Thus, unlike many other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values f*(x) (and sigma* in the numeric case), where * denotes a particular draw. The x is either a row from the training data (x.train) or the test data (x.test).
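Because the output is a collection of draws rather than a single fit, summaries are computed across draws. The sketch below illustrates this with base R only. The matrix draws is a simulated stand-in (rows are draws, columns are x rows), not output from a BART fit, and the names ndpost, centers, post.mean and post.ci are chosen here for illustration.

```r
## Simulated stand-in for posterior draws f*(x): one row per MCMC draw,
## one column per row of x. A real fit would supply these values.
set.seed(1)
ndpost  <- 1000                 # number of kept draws (assumed)
centers <- c(-1, 0, 2)          # hypothetical true f(x) for three x rows
n       <- length(centers)
draws   <- matrix(rnorm(ndpost * n,
                        mean = rep(centers, each = ndpost), sd = 0.1),
                  nrow = ndpost, ncol = n)

## Posterior mean of f(x) for each x row: average over draws.
post.mean <- colMeans(draws)

## 95% posterior interval for each x row: quantiles over draws.
post.ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

print(round(rbind(mean = post.mean, post.ci), 2))
```

The same column-wise pattern applies to any quantity that is stored as a draws-by-observations matrix.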
Returns
Returns an object of type crisk2bart with predictions corresponding to newdata and newdata2.
Examples
library(BART)  ## provides the functions and the transplant data used below

data(transplant)

delta <- (as.numeric(transplant$event)-1)
## recode so that delta=1 is cause of interest; delta=2 otherwise
delta[delta==1] <- 4
delta[delta==2] <- 1
delta[delta>1] <- 2
table(delta, transplant$event)

times <- pmax(1, ceiling(transplant$futime/7)) ## weeks
##times <- pmax(1, ceiling(transplant$futime/30.5)) ## months
table(times)

typeO <- 1*(transplant$abo=='O')
typeA <- 1*(transplant$abo=='A')
typeB <- 1*(transplant$abo=='B')
typeAB <- 1*(transplant$abo=='AB')
table(typeA, typeO)

x.train <- cbind(typeO, typeA, typeB, typeAB)

x.test <- cbind(1, 0, 0, 0)
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

## parallel::mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
    ## test BART with token run to ensure installation works
    post <- mc.crisk2.bart(x.train=x.train, times=times, delta=delta,
                           seed=99, mc.cores=2, nskip=5, ndpost=5,
                           keepevery=1)

    pre <- surv.pre.bart(x.train=x.train, x.test=x.test,
                         times=times, delta=delta)

    K <- post$K

    pred <- mc.crisk2.pwbart(pre$tx.test, pre$tx.test,
                             post$treedraws, post$treedraws2,
                             post$binaryOffset, post$binaryOffset2)
}

## Not run:

## run one long MCMC chain in one process
## set.seed(99)
## post <- crisk2.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## in the interest of time, consider speeding it up by parallel processing
## run "mc.cores" number of shorter MCMC chains in parallel processes
post <- mc.crisk2.bart(x.train=x.train,
                       times=times, delta=delta,
                       x.test=x.test, seed=99, mc.cores=8)

## check <- mc.crisk2.pwbart(post$tx.test, post$tx.test,
##                           post$treedraws, post$treedraws2,
##                           post$binaryOffset,
##                           post$binaryOffset2, mc.cores=8)
check <- predict(post, newdata=post$tx.test, newdata2=post$tx.test2,
                 mc.cores=8)

## survival: stored test predictions versus recomputed predictions
print(c(post$surv.test.mean[1], check$surv.test.mean[1],
        post$surv.test.mean[1]-check$surv.test.mean[1]), digits=22)
print(all(round(post$surv.test.mean, digits=9)==
          round(check$surv.test.mean, digits=9)))

## cumulative incidence, cause 1
print(c(post$cif.test.mean[1], check$cif.test.mean[1],
        post$cif.test.mean[1]-check$cif.test.mean[1]), digits=22)
print(all(round(post$cif.test.mean, digits=9)==
          round(check$cif.test.mean, digits=9)))

## cumulative incidence, cause 2
print(c(post$cif.test2.mean[1], check$cif.test2.mean[1],
        post$cif.test2.mean[1]-check$cif.test2.mean[1]), digits=22)
print(all(round(post$cif.test2.mean, digits=9)==
          round(check$cif.test2.mean, digits=9)))

## End(Not run)
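The checks at the end of the example compare stored and recomputed predictions after rounding to nine digits. Rounding both sides can report a mismatch when two nearly identical values fall on opposite sides of a rounding boundary, so an absolute-tolerance comparison is a possible alternative. The sketch below uses base R only. The helper name same.within and the vectors stored and recomputed are hypothetical stand-ins, not part of the BART package or output from a real fit.

```r
## Hypothetical helper: TRUE when two numeric vectors of equal length
## agree element-wise within an absolute tolerance 'tol'.
same.within <- function(a, b, tol = 1e-9) {
    stopifnot(length(a) == length(b))
    all(abs(a - b) <= tol)
}

## Stand-ins for two sets of predictions (made-up numbers).
stored     <- c(0.95, 0.90, 0.85)
recomputed <- stored + c(0, 1e-12, -1e-12)

print(same.within(stored, recomputed))     # differences far below tol
print(same.within(stored, stored + 1e-3))  # differences far above tol
```

With fitted objects, the same helper could be applied to pairs such as post$surv.test.mean and check$surv.test.mean.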