######################Tests based on composite likelihood and latent Gaussian models 
##Ref:  Han F, Pan W (2011). A Composite Likelihood Approach to Latent Multivariate Gaussian Modeling of SNP Data with Application to Genetic Association Testing. To appear in Biometrics.
######################modified from the program of Fang HAN (Email: fhan@jhsph.edu) on 3/08/11

library("mvtnorm")
# Nominal significance level; not referenced by any function defined below.
alpha <- 0.05

# Introduction to input parameters:

# Y: disease labels; 0 for controls, 1 for cases;
# X: genotype data; each row for a subject, and each column for an SNP;
# B: the number of permutations; 

############################
# Output: the p-values of the eight tests, in the following order:
# 1-4. L:Latent, Wald-type: Wald, SSB, SSBw, UminP
# 5-8. L:Latent, CLRT or LRT: CLRT-P, CLRT-S, LRT-P, LRT-S




############################################
####### Helper functions
############################################
## minP-type p-value for a vector of statistics U:
##   P( max_k |Z_k| >= max_k |U_k| ),  Z ~ N(0, V),
## computed as one minus the N(0, V) mass of the box [-m, m]^n, m = max|U|.
##   U: numeric vector; V: n x n covariance matrix (pvTest passes a
##      correlation matrix).
## The box probability comes from mvtnorm::pmvnorm, whose default algorithm
## is Monte-Carlo based (see the mvtnorm docs), so the value may vary slightly
## with the RNG state.  Returns a plain numeric scalar.
PowerUniv <- function(U, V) {
  nDim <- nrow(V)
  m <- as.numeric(max(abs(U)))
  box <- rep(m, nDim)
  inside <- pmvnorm(lower = -box, upper = box, mean = numeric(nDim), sigma = V)
  as.numeric(1 - inside)
}


## Three-moment chi-square approximation used by pvTest: a weighted sum of
## independent chi-square(1) variables, sum_k cr[k] * chisq_1, is approximated
## by a * chisq_d + b, with a, b, d matching its first three cumulants.
## Returns the upper-tail p-value of `stat` under that approximation.
.pvSumChisq <- function(stat, cr) {
  s1 <- sum(cr)
  s2 <- sum(cr^2)
  s3 <- sum(cr^3)
  a <- s3 / s2
  b <- s1 - s2^2 / s3
  d <- s2^3 / s3^2
  pchisq((as.numeric(stat) - b) / a, d, lower.tail = FALSE)
}

## Four Wald-type p-values for a score-like vector U with null covariance V.
##   U: numeric vector of length k (LVMpv passes case-minus-control
##      differences of the latent-model parameters).
##   V: k x k null covariance of U (LVMpv passes a permutation covariance).
## V is first made positive definite with positiveR().  Returns a 1 x 4
## matrix with columns
##   pTscore - Wald/score test, U' CovS^- U against chisq(rank(CovS));
##   pTg1    - SSB,  sum(U^2), three-moment chi-square approximation;
##   pTg2    - SSBw, sum(U^2 / diag(CovS)), same approximation;
##   pTus    - UminP, max_k |U_k| / sd_k via PowerUniv().
pvTest <- function(U, V) {
  U <- as.numeric(U)
  CovS <- positiveR(V)
  # cov2cor(CovS) = D^{-1/2} CovS D^{-1/2}, D = diag(diag(CovS)).
  CorS <- cov2cor(CovS)
  CovS.edecomp <- eigen(CovS, symmetric = TRUE)

  # SSB: sum(U^2) is distributed as sum_k lambda_k chisq_1, lambda = eig(CovS).
  pTg1 <- .pvSumChisq(sum(U^2), CovS.edecomp$values)

  # SSBw: the weights are the eigenvalues of CovS %*% diag(1 / diag(CovS)),
  # which is similar to the symmetric CorS; using CorS guarantees real
  # eigenvalues (eigen() on the non-symmetric product may return complex ones).
  pTg2 <- .pvSumChisq(sum(U^2 / diag(CovS)),
                      eigen(CorS, symmetric = TRUE, only.values = TRUE)$values)

  # Wald/score test with a generalized inverse: eigenvalues below 1e-8 in
  # absolute value are treated as zero and dropped.
  keep <- abs(CovS.edecomp$values) > 1e-8
  CovS.rank <- sum(keep)
  proj <- as.vector(crossprod(CovS.edecomp$vectors, U))
  Tscore <- sum(proj[keep]^2 / CovS.edecomp$values[keep])
  pTscore <- pchisq(Tscore, CovS.rank, lower.tail = FALSE)

  # UminP: standardized marginal statistics with their correlation matrix.
  Tus <- abs(U) / sqrt(diag(CovS))
  pTus <- as.numeric(PowerUniv(Tus, CorS))

  cbind(pTscore, pTg1, pTg2, pTus)
}


## 2 x 2 correlation matrix with off-diagonal entry rho (rho: a scalar).
corrmatrix <- function(rho) {
  out <- diag(2)
  out[1, 2] <- rho
  out[2, 1] <- rho
  out
}

## Integration bounds of the latent normal for a vector of genotypes.
##   a:      genotype codes (0, 1, 2), one per SNP listed in `index`;
##   index:  SNP indices selecting the thresholds in z0 / z1;
##   z0, z1: per-SNP latent thresholds (genotype 0 lies below z0, 1 between
##           z0 and z1, 2 above z1).
## Returns a length(a) x 2 matrix with columns "L" (lower) and "U" (upper).
## Codes other than 0/1/2 (including NA) are an error; previously they were
## silently skipped, giving bound vectors shorter than `a`.
getLU <- function(a, index, z0, z1) {
  a <- as.numeric(a)
  if (!all(a %in% 0:2)) {
    stop("getLU: genotype codes must be 0, 1 or 2", call. = FALSE)
  }
  lo <- z0[index]
  hi <- z1[index]
  L <- ifelse(a == 0, -Inf, ifelse(a == 1, lo, hi))
  U <- ifelse(a == 0, lo, ifelse(a == 1, hi, Inf))
  cbind(L = L, U = U)
}

## Replace +Inf by `bound` and -Inf by -bound (default 100, the original
## hard-coded cap) so thresholds can be differenced/averaged downstream.
## NA and finite values are left unchanged; empty input returns empty
## (the former 1:length(x) loop errored on both).
controlInf <- function(x, bound = 100) {
  x[is.infinite(x) & x > 0] <- bound
  x[is.infinite(x) & x < 0] <- -bound
  x
}

## Rescale a square covariance-like matrix to unit diagonal:
##   out[i, j] = M[i, j] / sqrt(M[i, i] * M[j, j]).
## Only the upper triangle (with the diagonal) of M is read; the lower
## triangle of the result mirrors it, so the output is always symmetric.
## Other attributes of M (e.g. dimnames) are kept.
standadizeM <- function(M) {
  v <- diag(M)
  scaled <- M / sqrt(outer(v, v))
  upper <- upper.tri(scaled, diag = TRUE)
  out <- M
  out[upper] <- scaled[upper]
  lower <- lower.tri(out)
  out[lower] <- t(out)[lower]
  out
}

## Make a symmetric matrix positive definite by eigenvalue flooring (every
## eigenvalue below `eps`, default 1e-3, is raised to `eps`), then rescale it
## to unit diagonal with standadizeM().
## Returns cbind(sum = <number of eigenvalues of CORR below eps>,
## <p x p adjusted matrix>): a p x (p + 1) matrix whose first column repeats
## the count.
positiveM <- function(CORR, eps = 1e-3) {
  CORR.edecomp <- eigen(CORR)
  ev <- pmax(CORR.edecomp$values, eps)
  Q <- CORR.edecomp$vectors
  # nrow = is required: diag(<length-1 vector>) would build an identity
  # matrix of that size rather than a 1 x 1 matrix.
  adjustCORR <- standadizeM(Q %*% diag(ev, nrow = length(ev)) %*% t(Q))
  nFloored <- sum(CORR.edecomp$values < eps)
  cbind(sum = nFloored, adjustCORR)
}

## Make a symmetric matrix positive definite by eigenvalue flooring: every
## eigenvalue below `eps` (default 1e-3, the original hard-coded value) is
## raised to `eps` and the matrix is rebuilt as Q diag(lambda) Q'.
## The diagonal is not rescaled.
positiveR <- function(CORR, eps = 1e-3) {
  CORR.edecomp <- eigen(CORR)
  ev <- pmax(CORR.edecomp$values, eps)
  Q <- CORR.edecomp$vectors
  # nrow = is required: diag(<length-1 vector>) would build an identity
  # matrix of that size rather than a 1 x 1 matrix.
  Q %*% diag(ev, nrow = length(ev)) %*% t(Q)
}


## Moment-match a sample x to a * chisq_d + b.
## With mean mu, variance sigma2 and third central moment m3, the scaled
## chi-square has variance 2 a^2 d and third central moment 8 a^3 d, giving
##   a = m3 / (4 sigma2),  d = sigma2 / (2 a^2),  b = mu - a d.
## Returns c(a, b, d).  If x is negatively skewed, a < 0 and the fit is not
## meaningful for pchisq().
parchisq <- function(x) {
  mu <- mean(x)
  sigma2 <- var(x)
  # Third central moment taken directly from deviations.  This is
  # shift-invariant and avoids cancellation; the former
  # mean(x^3) - mu^3 - 3 * mu * var(x) was also biased by 3 * mu * var(x) / n,
  # because var() uses the n - 1 denominator.
  m3 <- mean((x - mu)^3)
  a <- m3 / (4 * sigma2)
  d <- sigma2 / (2 * a^2)
  b <- mu - a * d
  c(a, b, d)
}

## Pairwise composite-likelihood fit of the latent multivariate normal model.
##   X: n x p genotype matrix with entries 0, 1, 2 (p >= 1).
## SNP k is modelled by thresholding a latent N(0, 1) at
## z0[k] = qnorm(P(X = 0)) and z1[k] = qnorm(1 - P(X = 2)), estimated from the
## observed genotype frequencies.  For each SNP pair, the latent correlation
## maximizes the bivariate-normal cell log-likelihood (nlminb on [-1, 1]).
## Returns cbind(CL, z0, z1, CORR): a p x (p + 3) matrix with
##   column 1:     the summed maximized pairwise log-likelihood, on every row;
##   columns 2-3:  the thresholds z0, z1;
##   last p cols:  the estimated correlation matrix.
eCov <- function(X) {
  if (!all(as.matrix(X) %in% 0:2)) {
    stop("eCov: X must contain only genotype codes 0, 1, 2", call. = FALSE)
  }
  nSNP <- dim(X)[2]
  nSample <- dim(X)[1]

  freq0 <- unname(colSums(X == 0)) / nSample
  freq2 <- unname(colSums(X == 2)) / nSample
  z0 <- qnorm(freq0)
  z1 <- qnorm(1 - freq2)

  CORR <- matrix(1, nrow = nSNP, ncol = nSNP)
  Lstat <- 0
  # seq_len() makes the pair loop empty when there is a single SNP.
  for (i in seq_len(nSNP - 1)) {
    for (j in (i + 1):nSNP) {
      index <- c(i, j)
      # 3 x 3 counts of observed genotype pairs: row = genotype at SNP i + 1,
      # column = genotype at SNP j + 1 (column-major cell g_i + 3 * g_j + 1).
      counts <- matrix(tabulate(X[, i] + 3 * X[, j] + 1, nbins = 9), nrow = 3)
      seen <- counts > 0
      # Negative pairwise log-likelihood as a function of the correlation r.
      # The data enter only through `counts`, so each evaluation costs nine
      # bivariate probabilities instead of a pass over every sample.
      logLiklih <- function(r) {
        LLL <- matrix(0, nrow = 3, ncol = 3)
        for (i1 in 0:2) {
          for (j1 in 0:2) {
            LU <- getLU(c(i1, j1), index, z0, z1)
            LLL[i1 + 1, j1 + 1] <- log(abs(pmvnorm(
              lower = LU[, 1], upper = LU[, 2], corr = corrmatrix(r)
            )))
          }
        }
        # Only observed cells contribute (and 0 * -Inf = NaN is avoided).
        -sum(counts[seen] * LLL[seen])
      }
      fit <- nlminb(0, logLiklih, lower = -1, upper = 1)
      CORR[i, j] <- CORR[j, i] <- fit$par
      # nlminb's $objective is the objective at $par; no need to re-evaluate.
      Lstat <- Lstat + fit$objective
    }
  }
  cbind(-Lstat, z0, z1, CORR)
}

## Fit the latent model to one group of subjects and evaluate its likelihoods.
##   X: n x p genotype matrix (0, 1, 2).
## Returns
##   c(CL, fullLL, z0 (p values), z1 (p values), correlations (p(p-1)/2)),
## where
##   CL     = pairwise composite log-likelihood from eCov();
##   fullLL = full p-variate normal log-likelihood at the eCov() correlations
##            after positiveM() repair;
##   z0, z1 = thresholds, with infinities capped by controlInf();
##   correlations are listed row by row: (1,2), (1,3), ..., (1,p), (2,3), ...
estP <- function(X) {
  n <- dim(X)[1]
  p <- dim(X)[2]
  eCovb <- eCov(X)
  Lstat <- eCovb[1, 1]
  CORR <- eCovb[, -c(1, 2, 3), drop = FALSE]
  z0 <- eCovb[, 2]
  z1 <- eCovb[, 3]

  # positiveM() returns cbind(<count>, <p x p matrix>); keep the matrix.
  adjustCORR <- positiveM(CORR)[, -1, drop = FALSE]

  # Full likelihood: each subject contributes the latent-normal probability of
  # its genotype box (the uncapped, possibly infinite thresholds are used).
  logLiklih <- 0
  for (s in seq_len(n)) {
    LU <- getLU(X[s, ], seq_len(p), z0, z1)
    logLiklih <- logLiklih +
      as.numeric(log(pmvnorm(lower = LU[, 1], upper = LU[, 2], corr = adjustCORR)))
  }

  # Row-major upper triangle; empty (instead of an error) when p = 1.
  cCORR <- t(adjustCORR)[lower.tri(adjustCORR)]

  z0 <- controlInf(z0)
  z1 <- controlInf(z1)
  c(Lstat, logLiklih, z0, z1, cCORR)
}

## Observed statistics comparing cases (Y == 1) with controls (Y == 0).
## Returns c(LRT, CLRT, D) where
##   LRT  = -2 * (fullLL(all) - (fullLL(cases) + fullLL(controls))),
##   CLRT = the same with the pairwise composite log-likelihoods,
##   D    = case-minus-control differences of the remaining estP() entries
##          (thresholds z0, z1 and pairwise correlations).
LV <- function(Y, X) {
  # drop = FALSE keeps a matrix when a group has one subject or X has one SNP.
  Case <- X[Y == 1, , drop = FALSE]
  Control <- X[Y == 0, , drop = FALSE]

  caseP <- estP(Case)
  controlP <- estP(Control)
  wholeP <- estP(X)

  LV2stat <- as.numeric(-2 * (wholeP[2] - (caseP[2] + controlP[2])))
  LV3stat <- as.numeric(-2 * (wholeP[1] - (caseP[1] + controlP[1])))
  LV1stat <- caseP[-c(1, 2)] - controlP[-c(1, 2)]

  c(LV2stat, LV3stat, LV1stat)
}

## Permutation null distribution of the LV() statistics.
## For sim = 1..B:
##   - the global RNG is seeded with set.seed(sim) (this overwrites the
##     caller's RNG state);
##   - the labels are reshuffled, keeping the observed number of cases
##     (previously length(Y)/2 was assumed, which is wrong for unbalanced
##     data; for balanced data the draws are unchanged);
##   - LV() is recomputed.
## Returns a B x (2 + k) matrix: column "H0LV2" is the LRT statistic,
## "H0LV3" the CLRT statistic, and the remaining k columns the parameter
## differences.
H0LV <- function(Y, X, B) {
  if (B < 1) stop("H0LV: B must be at least 1", call. = FALSE)
  nObs <- length(Y)
  nCase <- sum(Y == 1)
  H0 <- vector("list", B)
  for (sim in seq_len(B)) {
    set.seed(sim)
    case <- sample(nObs, nCase)
    Y[case] <- 1
    Y[-case] <- 0
    H0[[sim]] <- LV(Y, X)
  }
  H0 <- do.call(rbind, H0)
  cbind(H0LV2 = H0[, 1], H0LV3 = H0[, 2], H0[, -c(1, 2), drop = FALSE])
}


############################################
#############  The main program  ###########
############################################

## Permutation and scaled-chi-square p-values for one likelihood-ratio-type
## statistic.  `stat` is the observed value and `H0stat` the B permutation
## values.  Returns c(permutation p-value, p-value from a * chisq_d + b fitted
## to H0stat by parchisq()).
.LVpermPv <- function(stat, H0stat, B) {
  permP <- sum(stat < H0stat) / B
  fit <- parchisq(H0stat)
  scaledP <- pchisq((stat - fit[2]) / fit[1], fit[3], lower.tail = FALSE)
  c(permP, as.numeric(scaledP))
}

## Main entry point.
##   Y: 0/1 disease labels; X: n x p genotype matrix (0, 1, 2);
##   B: number of permutations.
## Returns eight p-values:
##   1-4: Wald, SSB, SSBw, UminP (pvTest on the latent-parameter
##        differences, with covariance estimated from the B permutations);
##   5-6: CLRT permutation, CLRT scaled chi-square;
##   7-8: LRT permutation,  LRT scaled chi-square.
LVMpv <- function(Y, X, B) {
  LVstat <- LV(Y, X)
  H0LVstat <- H0LV(Y, X, B)

  # Wald-type tests on the case-minus-control parameter differences.
  LV1 <- LVstat[-c(1, 2)]
  H0LV1 <- H0LVstat[, -c(1, 2), drop = FALSE]
  pvLV1 <- pvTest(LV1, cov(H0LV1))

  # Element/column 2 is the composite-likelihood ratio (CLRT);
  # element/column 1 is the full-likelihood ratio (LRT).
  pvLV3 <- .LVpermPv(LVstat[2], H0LVstat[, 2], B)
  pvLV2 <- .LVpermPv(LVstat[1], H0LVstat[, 1], B)

  c(pvLV1, pvLV3, pvLV2)
}