Predicting new observations with a previously fitted BART model
BART is a Bayesian sum-of-trees model.
For a numeric response y, we have y = f(x) + e, where e ~ N(0, sigma^2).
f is the sum of many tree models. The goal is to have very flexible inference for the unknown function f.
In the spirit of ensemble models, each tree is constrained by a prior to be a weak learner so that it contributes only a small amount to the overall fit.
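To make the sum-of-trees idea concrete, here is a minimal base-R sketch. It is not the BART sampler: f is simply built as a sum of hypothetical depth-one trees (stumps), each contributing a small step, and y is simulated as f(x) + e. The cutpoints, step sizes, and sigma are invented for illustration.

```r
set.seed(1)
n <- 200
x <- runif(n)

## hypothetical weak learners: each stump has a cutpoint and a small step size
m <- 50
cuts  <- runif(m)
steps <- rnorm(m, sd = 0.1)

## one stump returns +step to the right of its cutpoint and -step to the left
stump <- function(x, cut, step) ifelse(x > cut, step, -step)

## evaluate every stump at every x: an n x m matrix of small contributions
tree.fits <- sapply(seq_len(m), function(j) stump(x, cuts[j], steps[j]))

## f(x) is the sum of the m stump outputs
f <- rowSums(tree.fits)

## numeric response: y = f(x) + e with e ~ N(0, sigma^2)
sigma <- 0.2
y <- f + rnorm(n, sd = sigma)
```

No single stump moves the fit by more than its own step size, yet the sum can trace out a flexible function of x.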
## S3 method for class 'survbart'
predict(object, newdata, mc.cores=1, openmp=(mc.cores.openmp()>0), ...)
Arguments
object: Object returned from a previous BART fit with surv.bart or mc.surv.bart.
newdata: Matrix of covariates at which to predict the distribution of t (in the example, the tx.test matrix returned by surv.pre.bart).
mc.cores: Number of threads to utilize.
openmp: Logical value indicating whether OpenMP is used for parallel processing. This depends on whether OpenMP is available on your system, which by default is checked with mc.cores.openmp.
...: Other arguments which will be passed on to pwbart.
Details
BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior (f, sigma) | (x, y) in the numeric y case and just f in the binary y case.
Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values f*(x) (and sigma* in the numeric case), where * denotes a particular draw. The x is either a row from the training data (x.train) or the test data (x.test).
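Because the output is a collection of draws rather than a single fit, summaries are computed across draws. The base-R sketch below uses a simulated stand-in for the draws f*(x); the layout (one row per draw, one column per x), the number of draws, and the values themselves are assumptions for illustration, not output from a BART fit.

```r
set.seed(2)
ndpost <- 1000   # assumed number of retained posterior draws
n.test <- 4      # assumed number of x rows

## simulated stand-in for f*(x): column j holds the draws for the j-th x
true.f <- c(-1, 0, 0.5, 2)
draws <- matrix(rnorm(ndpost * n.test,
                      mean = rep(true.f, each = ndpost), sd = 0.3),
                nrow = ndpost, ncol = n.test)

## posterior mean and 95% interval for each x, computed across draws
post.mean  <- colMeans(draws)
post.lower <- apply(draws, 2, quantile, probs = 0.025)
post.upper <- apply(draws, 2, quantile, probs = 0.975)

cbind(mean = post.mean, lower = post.lower, upper = post.upper)
```

Any other functional of f, such as a difference between two x rows, can be summarized the same way by computing it draw by draw first.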
Returns
Returns an object of type survbart with predictions corresponding to newdata.
Examples
library(BART) ## provides surv.bart, surv.pre.bart and the lung data

## load the advanced lung cancer example
data(lung)

group <- -which(is.na(lung[ , 7])) ## remove missing row for ph.karno
times <- lung[group, 2]   ##lung$time
delta <- lung[group, 3]-1 ##lung$status: 1=censored, 2=dead
                          ##delta: 0=censored, 1=dead

## this study reports time in days rather than months like other studies
## coarsening from days to months will reduce the computational burden
times <- ceiling(times/30)

summary(times)
table(delta)

x.train <- as.matrix(lung[group, c(4, 5, 7)]) ## matrix of observed covariates

## lung$age:        Age in years
## lung$sex:        Male=1 Female=2
## lung$ph.karno:   Karnofsky performance score (dead=0:normal=100:by=10)
##                  rated by physician

dimnames(x.train)[[2]] <- c('age(yr)', 'M(1):F(2)', 'ph.karno(0:100:10)')

summary(x.train[ , 1])
table(x.train[ , 2])
table(x.train[ , 3])

x.test <- matrix(nrow=84, ncol=3) ## matrix of covariate scenarios

dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

i <- 1

for(age in 5*(9:15)) for(sex in 1:2) for(ph.karno in 10*(5:10)) {
    x.test[i, ] <- c(age, sex, ph.karno)
    i <- i+1
}

## this x.test is relatively small, but often you will want to
## predict for a large x.test matrix which may cause problems
## due to consumption of RAM so we can predict separately

## mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
    ##test BART with token run to ensure installation works
    set.seed(99)
    post <- surv.bart(x.train=x.train, times=times, delta=delta,
                      nskip=5, ndpost=5, keepevery=1)

    pre <- surv.pre.bart(x.train=x.train, times=times, delta=delta,
                         x.test=x.test)

    pred <- predict(post, pre$tx.test)
    ##pred. <- surv.pwbart(pre$tx.test, post$treedraws, post$binaryOffset)
}

## Not run:

## run one long MCMC chain in one process
set.seed(99)
post <- surv.bart(x.train=x.train, times=times, delta=delta)

## run "mc.cores" number of shorter MCMC chains in parallel processes
## post <- mc.surv.bart(x.train=x.train, times=times, delta=delta,
##                      mc.cores=5, seed=99)

pre <- surv.pre.bart(x.train=x.train, times=times, delta=delta,
                     x.test=x.test)

pred <- predict(post, pre$tx.test)

## let's look at some survival curves
## first, a younger group with a healthier KPS
## age 50 with KPS=90: males and females
## males: row 17, females: row 23
x.test[c(17, 23), ]

low.risk.males <- 16*post$K+1:post$K ## K=unique times including censoring
low.risk.females <- 22*post$K+1:post$K

plot(post$times, pred$surv.test.mean[low.risk.males], type='s', col='blue',
     main='Age 50 with KPS=90', xlab='t', ylab='S(t)', ylim=c(0, 1))
points(post$times, pred$surv.test.mean[low.risk.females], type='s', col='red')

## End(Not run)
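The example selects the survival curve for row 17 of x.test with 16*post$K + 1:post$K, which implies that the survival estimates are stored as one long vector with K consecutive entries per x.test row. The base-R sketch below illustrates that indexing on a toy vector; K, the number of scenarios, the values, and the helper row.index are hypothetical and not part of the package.

```r
K <- 3        # hypothetical number of distinct time points
n.scen <- 4   # hypothetical number of covariate scenarios (rows of x.test)

## stand-in for pred$surv.test.mean: K consecutive values per scenario
surv <- seq_len(n.scen * K) / 100

## positions of the K values that belong to scenario i
row.index <- function(i, K) (i - 1) * K + 1:K

## curve for scenario 2
surv[row.index(2, K)]

## equivalently, reshape to one row per scenario and one column per time point
surv.mat <- matrix(surv, nrow = n.scen, ncol = K, byrow = TRUE)
stopifnot(identical(surv.mat[2, ], surv[row.index(2, K)]))
```

With this layout, row.index(17, post$K) would reproduce the low.risk.males index from the example, and the reshaped matrix makes it easy to plot several scenarios at once with matplot.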