From 7e8cb35e1bef7eb4045f331d5a46c8abbc8eba4c Mon Sep 17 00:00:00 2001 From: Alice MacQueen Date: Thu, 1 Apr 2021 16:24:45 -0500 Subject: [PATCH] Change calculation of mash posterior matrices to batch process SNPs and speed computation --- R/wrapper.R | 107 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 88 insertions(+), 19 deletions(-) diff --git a/R/wrapper.R b/R/wrapper.R index 176266f..2de5152 100644 --- a/R/wrapper.R +++ b/R/wrapper.R @@ -310,30 +310,20 @@ dive_phe2mash <- function(df, snp, type = "linear", svd = NULL, suffix = "", Bhat_random <- as.matrix(gwas2[random_sample$value, ind_estim]) Shat_random <- as.matrix(gwas2[random_sample$value, ind_se]) - ## Full data: Both Bhat and Shat are zero (or near zero) for some input data. - ## Filter this data from the input, or set Shat to a positive number to - ## avoid numerical issues. which rowSums are 0, filter these out or make +. - ## Eventually want to batch process SNPs through this, not make a full set. - Bhat_full <- as.matrix(gwas2[, ind_estim]) - Shat_full <- as.matrix(gwas2[, ind_se]) - ## name the columns for these conditions (usually the phenotype) colnames(Bhat_strong) <- gwas_metadata$phe colnames(Shat_strong) <- gwas_metadata$phe colnames(Bhat_random) <- gwas_metadata$phe colnames(Shat_random) <- gwas_metadata$phe - colnames(Bhat_full) <- gwas_metadata$phe - colnames(Shat_full) <- gwas_metadata$phe # 5. mash ---- data_r <- mashr::mash_set_data(Bhat_random, Shat_random) - printf2(verbose = verbose, "\nEstimating correlation structure in the null tests from a random sample of clumped data.") + printf2(verbose = verbose, "\nEstimating correlation structure in the null tests from a random sample of clumped data.\n") Vhat <- mashr::estimate_null_correlation_simple(data = data_r) data_strong <- mashr::mash_set_data(Bhat_strong, Shat_strong, V = Vhat) data_random <- mashr::mash_set_data(Bhat_random, Shat_random, V = Vhat) - data_full <- mashr::mash_set_data(Bhat_full, Shat_full, V = Vhat) U_c <- mashr::cov_canonical(data_random) if (is.na(U.ed[1])) { @@ -365,10 +355,79 @@ dive_phe2mash <- function(df, snp, type = "linear", svd = NULL, suffix = "", } else { m = mashr::mash(data_random, Ulist = c(U_ed, U_c), outputlevel = 1) printf2(verbose = verbose, "\nNo user-specified covariance matrices were included in the mash fit.") - } + } + printf2(verbose = verbose, "\nComputing posterior weights for all effects using the mash fit from the random tests.") - m2 = mashr::mash(data_full, g = ashr::get_fitted_g(m), fixg = TRUE) + ## Batch process SNPs through this, don't run on full set if > 20000 rows. + ## Even for 1M SNPs, because computing posterior weights scales quadratically + ## with the number of rows in Bhat and Shat. 10K = 13s, 20K = 55s; 40K = 218s + ## By my calc, this starts getting unwieldy between 4000 and 8000 rows. + ## See mash issue: https://github.com/stephenslab/mashr/issues/87 + if(gwas2$nrow > 20000){ + subset_size <- 4000 + n_subsets <- ceiling(gwas2$nrow / subset_size) + printf2(verbose = verbose, "\nSplitting data into %s sets of 4K markers to speed computation.\n", + n_subsets) + for (i in 1:n_subsets) { + if(i < n_subsets){ + from <- (i*subset_size - (subset_size - 1)) + to <- i*subset_size + row_subset <- from:to + } else { + from <- n_subsets*subset_size - (subset_size - 1) + to <- gwas2$nrow + row_subset <- from:to + } + Bhat_subset <- as.matrix(gwas2[row_subset, ind_estim]) + Shat_subset <- as.matrix(gwas2[row_subset, ind_se]) + colnames(Bhat_subset) <- gwas_metadata$phe + colnames(Shat_subset) <- gwas_metadata$phe + data_subset <- mashr::mash_set_data(Bhat_subset, Shat_subset, V = Vhat) + m_subset = mashr::mash(data_subset, g = ashr::get_fitted_g(m), fixg = TRUE) + + if (i == 1){ + m2 <- m_subset + } else { # make a new mash object with the combined data. + PosteriorMean = rbind(m2$result$PosteriorMean, m_subset$result$PosteriorMean) + PosteriorSD = rbind(m2$result$PosteriorSD, m_subset$result$PosteriorSD) + lfdr = rbind(m2$result$lfdr, m_subset$result$lfdr) + NegativeProb = rbind(m2$result$NegativeProb, m_subset$result$NegativeProb) + lfsr = rbind(m2$result$lfsr, m_subset$result$lfsr) + posterior_matrices = list(PosteriorMean = PosteriorMean, + PosteriorSD = PosteriorSD, + lfdr = lfdr, + NegativeProb = NegativeProb, + lfsr = lfsr) + loglik = m2$loglik # NB must recalculate from sum(vloglik) at end + vloglik = rbind(m2$vloglik, m_subset$vloglik) + null_loglik = c(m2$null_loglik, m_subset$null_loglik) + alt_loglik = rbind(m2$alt_loglik, m_subset$alt_loglik) + fitted_g = m2$fitted_g # all four components are equal + posterior_weights = rbind(m2$posterior_weights, m_subset$posterior_weights) + alpha = m2$alpha # equal + m2 = list(result = posterior_matrices, + loglik = loglik, + vloglik = vloglik, + null_loglik = null_loglik, + alt_loglik = alt_loglik, + fitted_g = fitted_g, + posterior_weights = posterior_weights, + alpha = alpha) + class(m2) = "mash" + } + } + loglik = sum(m2$vloglik) + m2$loglik <- loglik + # total loglik in mash function is: loglik = sum(vloglik) + } else { + Bhat_full <- as.matrix(gwas2[, ind_estim]) + Shat_full <- as.matrix(gwas2[, ind_se]) + colnames(Bhat_full) <- gwas_metadata$phe + colnames(Shat_full) <- gwas_metadata$phe + data_full <- mashr::mash_set_data(Bhat_full, Shat_full, V = Vhat) + m2 = mashr::mash(data_full, g = ashr::get_fitted_g(m), fixg = TRUE) + } return(m2) } @@ -454,7 +513,11 @@ div_gwas <- function(df, snp, type, svd, npcs){ return(gwaspc) } -#' Verbose? +#' @title print message function if verbose +#' +#' @param verbose Logical. If TRUE, print progress messages. +#' @param ... Other arguments to `printf()` +#' #' @importFrom bigassertr printf printf2 <- function(verbose, ...) if (verbose) { printf(...) } @@ -674,7 +737,6 @@ div_lambda_GC <- function(df, type = c("linear", "logistic"), snp, " PC #'s to test (npcs).")) } - G <- snp$genotypes LambdaGC <- as_tibble(matrix(data = @@ -850,9 +912,15 @@ check_gwas <- function(df1, phename, type, nPhe, minphe, nLev){ return(gwas_ok) } +#' @title A wrapper function to `stop` call +#' +#' @param x input matrix +#' @param msg Character string. A message to append to the stop call. +labelled_stop = function(x, msg) + stop(paste(gsub("\\s+", " ", paste0(deparse(x))), msg), call.=F) -## @title Basic sanity check for covariance matrices -## @param X input matrix +#' @title Basic sanity check for covariance matrices +#' @param x input matrix check_covmat_basics = function(x) { label = substitute(x) if (!is.matrix(x)) @@ -872,8 +940,9 @@ check_covmat_basics = function(x) { return(TRUE) } -## @title check matrix for positive definitness -## @param X input matrix +#' @title check matrix for positive definitness +#' +#' @param x input matrix check_positive_definite = function(x) { check_covmat_basics(x) tryCatch(chol(x),