Classify matrix-valued data based on a matrix-normal linear discriminant; an object of class "MN".
Predicts class membership for new matrix-valued data from an object of class "MN", that is, a model fit by MatLDA or MN_MLE.
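As background, a common formulation of the matrix-normal linear discriminant is shown below. It is assumed here rather than taken from the package source; see Molstad and Rothman (2018) for the exact model and estimators. Each r × c observation X from class j is taken to be matrix normal, X ~ MN(μ_j, U, V), with class mean μ_j, a row covariance U and a column covariance V shared by all J classes, and class probability π_j. Because U and V are common to every class, the term that is quadratic in X is the same for all j and can be dropped. What remains is a discriminant that is linear in X:

```latex
\begin{aligned}
\log\{\pi_j f_j(X)\} &= \log \pi_j - \tfrac{1}{2}\operatorname{tr}\!\left\{V^{-1}(X-\mu_j)^{\top}U^{-1}(X-\mu_j)\right\} + \text{const}, \\
d_j(X) &= \log \pi_j + \operatorname{tr}\!\left(V^{-1}\mu_j^{\top}U^{-1}X\right) - \tfrac{1}{2}\operatorname{tr}\!\left(V^{-1}\mu_j^{\top}U^{-1}\mu_j\right), \\
\hat{y}(X) &= \arg\max_{j \in \{1,\dots,J\}} d_j(X),
\end{aligned}
```

where const collects the terms that do not depend on j.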
PredictMN(object, newdata, newclass = NULL)
Arguments
object: An object of class "MN"; output from MatLDA or MN_MLE.
newdata: New data to be classified; an r×c×Ntest array.
newclass: Class labels for the new data; values should be in {1, …, J}. Default is NULL.
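If the original labels are not already the integers 1, …, J, they have to be recoded, and the coding must agree with the one used when the model was fit. The R sketch below shows one way to do this. The helper recode_labels is hypothetical and not part of MatLDA. It assumes the integer codes follow the order of a fixed vector of training levels.

```r
## Hypothetical helper (not part of MatLDA): map arbitrary labels to the
## integers 1..J using a fixed vector of training levels, so that the codes
## used for fitting and for PredictMN's newclass agree.
recode_labels <- function(labels, train.levels) {
  codes <- match(labels, train.levels)   # position of each label in the levels
  if (anyNA(codes)) stop("labels contain values not seen in training")
  codes
}

train.labels <- c("a", "b", "c", "a")
lev <- sort(unique(train.labels))        # "a" "b" "c"
recode_labels(train.labels, lev)         # 1 2 3 1
recode_labels(c("c", "a"), lev)          # 3 1
```

Fixing the level vector once and reusing it avoids the subtle mismatch that arises when training and test labels are coded separately, for example when a class is absent from the test set.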
Returns
pred.class: A vector of length Ntest giving the predicted class membership from the input object.
misclass: If newclass is non-NULL, returns the misclassification proportion on the test data set.
prob.mat: An Ntest × J matrix containing the value of the discriminant for each of the J classes, evaluated at each test data point.
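The three return values can be related as follows for a linear discriminant rule. prob.mat holds one score per test point and class. Assuming the predicted class is the column with the largest score, misclass is the fraction of those predictions that disagree with newclass. The R sketch below illustrates this relationship. It is a stand-alone, hypothetical function (mn_scores, prec.row, prec.col, and pi.hat are illustrative names, not MatLDA's API). It plugs in known parameters rather than the estimates a fitted "MN" object would contain, and it is not the package's implementation.

```r
## Hypothetical sketch, not the MatLDA implementation.
## Scores each test matrix with the matrix-normal linear discriminant
##   d_j(X) = log(pi_j) + tr(V^{-1} t(mu_j) U^{-1} X) - 0.5 tr(V^{-1} t(mu_j) U^{-1} mu_j),
## assuming known class means and shared row/column precision matrices.
##   newdata:  r x c x Ntest array of observations
##   mu:       r x c x J array of class means
##   prec.row: r x r matrix U^{-1};  prec.col: c x c matrix V^{-1}
##   pi.hat:   length-J vector of class probabilities
mn_scores <- function(newdata, mu, prec.row, prec.col, pi.hat) {
  J <- dim(mu)[3]
  Ntest <- dim(newdata)[3]
  scores <- matrix(0, nrow = Ntest, ncol = J)
  for (j in 1:J) {
    # A = U^{-1} mu_j V^{-1} is r x c, and tr(V^{-1} t(mu_j) U^{-1} X) = sum(A * X)
    A <- prec.row %*% mu[, , j] %*% prec.col
    const <- log(pi.hat[j]) - 0.5 * sum(A * mu[, , j])
    for (i in 1:Ntest) {
      scores[i, j] <- const + sum(A * newdata[, , i])
    }
  }
  scores
}

## Toy example: two classes with means 0 and 1, identity precisions,
## and the two class means themselves used as test points.
r <- 3; c.dim <- 2; J <- 2
mu <- array(0, dim = c(r, c.dim, J))
mu[, , 2] <- 1
newdata <- mu                                # test point i equals class mean i
newclass <- c(1, 2)
scores <- mn_scores(newdata, mu, diag(r), diag(c.dim), rep(1/J, J))
pred.class <- apply(scores, 1, which.max)    # largest discriminant wins
misclass <- mean(pred.class != newclass)     # proportion misclassified
pred.class  # 1 2
misclass    # 0
```

The trace is computed as an elementwise sum, sum(A * X), which avoids forming the c × c product inside the trace for every test point.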
References
Molstad, A. J., and Rothman, A. J. (2018). A penalized likelihood method for classification with matrix-valued predictors. Journal of Computational and Graphical Statistics.
Examples
## Generate realizations of matrix-normal random variables
## set sample size, dimensionality, number of classes,
## and marginal class probabilities
N = 75
N.test = 150
N.total = N + N.test
r = 16
p = 16
C = 3
pi.list = rep(1/C, C)

## create class means
M.array = array(0, dim = c(r, p, C))
M.array[3:4, 3:4, 1] = 1
M.array[5:6, 5:6, 2] = .5
M.array[3:4, 3:4, 3] = -2
M.array[5:6, 5:6, 3] = -.5

## create covariance matrices U and V
Uinv = matrix(0, nrow = r, ncol = r)
for (i in 1:r) {
  for (j in 1:r) {
    Uinv[i, j] = .5^abs(i - j)
  }
}
eoU = eigen(Uinv)
Uinv.sqrt = tcrossprod(tcrossprod(eoU$vectors, diag(eoU$values^(1/2))), eoU$vectors)
Vinv = matrix(.5, nrow = p, ncol = p)
diag(Vinv) = 1
eoV = eigen(Vinv)
Vinv.sqrt = tcrossprod(tcrossprod(eoV$vectors, diag(eoV$values^(1/2))), eoV$vectors)

## generate N.total realizations of matrix-variate normal data
set.seed(1)
dat.array = array(0, dim = c(r, p, N.total))
class.total = numeric(length = N.total)
for (jj in 1:N.total) {
  class.total[jj] = sample(1:C, 1, prob = pi.list)
  dat.array[, , jj] = tcrossprod(crossprod(Uinv.sqrt, matrix(rnorm(r*p), nrow = r)), Vinv.sqrt) +
    M.array[, , class.total[jj]]
}

## store generated data
X = dat.array[, , 1:N]
X.test = dat.array[, , (N+1):N.total]
class = class.total[1:N]
class.test = class.total[(N+1):N.total]

## fit matrix-normal model using maximum likelihood
out = MN_MLE(X = X, class = class)

## use output to classify test set
check = PredictMN(out, newdata = X.test, newclass = class.test)

## print misclassification proportion
check$misclass