Change calculation of mash posterior matrices to batch process SNPs and speed computation

This commit is contained in:
2021-04-01 16:24:45 -05:00
parent b9d2edbad3
commit 7e8cb35e1b

View File

@@ -310,30 +310,20 @@ dive_phe2mash <- function(df, snp, type = "linear", svd = NULL, suffix = "",
Bhat_random <- as.matrix(gwas2[random_sample$value, ind_estim])
Shat_random <- as.matrix(gwas2[random_sample$value, ind_se])
## Full data: Both Bhat and Shat are zero (or near zero) for some input data.
## Filter this data from the input, or set Shat to a positive number to
## avoid numerical issues. which rowSums are 0, filter these out or make +.
## Eventually want to batch process SNPs through this, not make a full set.
Bhat_full <- as.matrix(gwas2[, ind_estim])
Shat_full <- as.matrix(gwas2[, ind_se])
## name the columns for these conditions (usually the phenotype)
colnames(Bhat_strong) <- gwas_metadata$phe
colnames(Shat_strong) <- gwas_metadata$phe
colnames(Bhat_random) <- gwas_metadata$phe
colnames(Shat_random) <- gwas_metadata$phe
colnames(Bhat_full) <- gwas_metadata$phe
colnames(Shat_full) <- gwas_metadata$phe
# 5. mash ----
data_r <- mashr::mash_set_data(Bhat_random, Shat_random)
printf2(verbose = verbose, "\nEstimating correlation structure in the null tests from a random sample of clumped data.")
printf2(verbose = verbose, "\nEstimating correlation structure in the null tests from a random sample of clumped data.\n")
Vhat <- mashr::estimate_null_correlation_simple(data = data_r)
data_strong <- mashr::mash_set_data(Bhat_strong, Shat_strong, V = Vhat)
data_random <- mashr::mash_set_data(Bhat_random, Shat_random, V = Vhat)
data_full <- mashr::mash_set_data(Bhat_full, Shat_full, V = Vhat)
U_c <- mashr::cov_canonical(data_random)
if (is.na(U.ed[1])) {
@@ -366,9 +356,78 @@ dive_phe2mash <- function(df, snp, type = "linear", svd = NULL, suffix = "",
m = mashr::mash(data_random, Ulist = c(U_ed, U_c), outputlevel = 1)
printf2(verbose = verbose, "\nNo user-specified covariance matrices were included in the mash fit.")
}
printf2(verbose = verbose, "\nComputing posterior weights for all effects
using the mash fit from the random tests.")
## Batch process SNPs through this; don't run on the full set if > 20000 rows.
## Batching pays off even for 1M SNPs, because computing posterior weights
## scales quadratically with the number of rows in Bhat and Shat:
## 10K = 13s, 20K = 55s, 40K = 218s.
## By my calc, this starts getting unwieldy between 4000 and 8000 rows.
## See mash issue: https://github.com/stephenslab/mashr/issues/87
if(gwas2$nrow > 20000){
subset_size <- 4000
n_subsets <- ceiling(gwas2$nrow / subset_size)
printf2(verbose = verbose, "\nSplitting data into %s sets of 4K markers to speed computation.\n",
n_subsets)
for (i in 1:n_subsets) {
if(i < n_subsets){
from <- (i*subset_size - (subset_size - 1))
to <- i*subset_size
row_subset <- from:to
} else {
from <- n_subsets*subset_size - (subset_size - 1)
to <- gwas2$nrow
row_subset <- from:to
}
Bhat_subset <- as.matrix(gwas2[row_subset, ind_estim])
Shat_subset <- as.matrix(gwas2[row_subset, ind_se])
colnames(Bhat_subset) <- gwas_metadata$phe
colnames(Shat_subset) <- gwas_metadata$phe
data_subset <- mashr::mash_set_data(Bhat_subset, Shat_subset, V = Vhat)
m_subset = mashr::mash(data_subset, g = ashr::get_fitted_g(m), fixg = TRUE)
if (i == 1){
m2 <- m_subset
} else { # make a new mash object with the combined data.
PosteriorMean = rbind(m2$result$PosteriorMean, m_subset$result$PosteriorMean)
PosteriorSD = rbind(m2$result$PosteriorSD, m_subset$result$PosteriorSD)
lfdr = rbind(m2$result$lfdr, m_subset$result$lfdr)
NegativeProb = rbind(m2$result$NegativeProb, m_subset$result$NegativeProb)
lfsr = rbind(m2$result$lfsr, m_subset$result$lfsr)
posterior_matrices = list(PosteriorMean = PosteriorMean,
PosteriorSD = PosteriorSD,
lfdr = lfdr,
NegativeProb = NegativeProb,
lfsr = lfsr)
loglik = m2$loglik # NB must recalculate from sum(vloglik) at end
vloglik = rbind(m2$vloglik, m_subset$vloglik)
null_loglik = c(m2$null_loglik, m_subset$null_loglik)
alt_loglik = rbind(m2$alt_loglik, m_subset$alt_loglik)
fitted_g = m2$fitted_g # all four components are equal
posterior_weights = rbind(m2$posterior_weights, m_subset$posterior_weights)
alpha = m2$alpha # equal
m2 = list(result = posterior_matrices,
loglik = loglik,
vloglik = vloglik,
null_loglik = null_loglik,
alt_loglik = alt_loglik,
fitted_g = fitted_g,
posterior_weights = posterior_weights,
alpha = alpha)
class(m2) = "mash"
}
}
loglik = sum(m2$vloglik)
m2$loglik <- loglik
# total loglik in mash function is: loglik = sum(vloglik)
} else {
Bhat_full <- as.matrix(gwas2[, ind_estim])
Shat_full <- as.matrix(gwas2[, ind_se])
colnames(Bhat_full) <- gwas_metadata$phe
colnames(Shat_full) <- gwas_metadata$phe
data_full <- mashr::mash_set_data(Bhat_full, Shat_full, V = Vhat)
m2 = mashr::mash(data_full, g = ashr::get_fitted_g(m), fixg = TRUE)
}
return(m2)
}
@@ -454,7 +513,11 @@ div_gwas <- function(df, snp, type, svd, npcs){
return(gwaspc)
}
#' Verbose?
#' @title print message function if verbose
#'
#' @param verbose Logical. If TRUE, print progress messages.
#' @param ... Other arguments to `printf()`
#'
#' @importFrom bigassertr printf
printf2 <- function(verbose, ...) {
  # Forward the format string and its arguments to printf() only when the
  # caller asked for progress output; otherwise do nothing, quietly.
  if (verbose) {
    printf(...)
  } else {
    invisible(NULL)
  }
}
@@ -674,7 +737,6 @@ div_lambda_GC <- function(df, type = c("linear", "logistic"), snp,
" PC #'s to test (npcs)."))
}
G <- snp$genotypes
LambdaGC <- as_tibble(matrix(data =
@@ -850,9 +912,15 @@ check_gwas <- function(df1, phename, type, nPhe, minphe, nLev){
return(gwas_ok)
}
#' @title A wrapper function to `stop` call
#'
#' @description Builds an error message of the form "<deparsed x> <msg>" and
#'   signals it with `stop()`, leaving the calling function out of the
#'   condition (`call. = FALSE`).
#'
#' @param x Object whose deparsed text is used as the message prefix.
#'   NOTE(review): callers presumably pass a `substitute()`d expression naming
#'   the offending argument -- confirm against the call sites.
#' @param msg Character string. A message to append to the stop call.
#'
#' @return Does not return; always signals an error.
labelled_stop <- function(x, msg) {
  # deparse() can split a long expression across several strings. Collapse
  # them into one label first so that `msg` is appended exactly once instead
  # of once per fragment, then squeeze runs of whitespace to single spaces.
  label <- gsub("\\s+", " ", paste(deparse(x), collapse = " "))
  stop(paste(label, msg), call. = FALSE)
}
## @title Basic sanity check for covariance matrices
## @param X input matrix
#' @title Basic sanity check for covariance matrices
#' @param x input matrix
check_covmat_basics = function(x) {
label = substitute(x)
if (!is.matrix(x))
@@ -872,8 +940,9 @@ check_covmat_basics = function(x) {
return(TRUE)
}
## @title check matrix for positive definitness
## @param X input matrix
#' @title check matrix for positive definiteness
#'
#' @param x input matrix
check_positive_definite = function(x) {
check_covmat_basics(x)
tryCatch(chol(x),