## prunning.R
  ## Top-down trimming of one branch tree (version 2).
  ##
  ## A node counts as non-significant when its wilcox_p vertex attribute is
  ## above wilcox_p_thresh or its mean_diff vertex attribute is above
  ## mean_diff_thresh. Every out-neighbour at distance 1 of such a node is
  ## removed, and the connected component that still holds the first vertex of
  ## `tree` is returned.
  ##
  ## tree: igraph object with vertex attributes name, wilcox_p and mean_diff
  ## wilcox_p_thresh: upper bound on wilcox_p for a node to keep its out-neighbours
  ## mean_diff_thresh: upper bound on mean_diff for a node to keep its out-neighbours
  ## value: the trimmed igraph object (`tree` itself when it has a single vertex);
  ##        stops when is_binary_tree() rejects the result
  trim_tree_adaptive_top_down_v2 = function( tree, wilcox_p_thresh, mean_diff_thresh )
  {
    ## a single-vertex tree cannot be trimmed any further
    if(igraph::vcount(tree)==1) return(tree)

    ## names of the nodes failing either criterion (which() drops NA comparisons)
    above_p = which(igraph::V(tree)$wilcox_p > wilcox_p_thresh)
    above_diff = which(igraph::V(tree)$mean_diff > mean_diff_thresh)
    nsig_nodes = union( igraph::V(tree)$name[above_p], igraph::V(tree)$name[above_diff] )

    ## direct out-neighbours of those nodes; mindist=1 leaves the nodes themselves out
    below_nsig = igraph::ego(tree, order=1, node=nsig_nodes, mode='out', mindist=1)
    doomed = names(unlist(below_nsig))

    pruned = if(length(doomed)==0) tree else tree - doomed

    ## deleting vertices may split the graph; keep the piece containing the first vertex
    root_name = igraph::V(tree)[1]$name
    pieces = igraph::decompose(pruned)
    holds_root = sapply( pieces, function(piece) root_name %in% igraph::V(piece)$name )
    pruned = pieces[[ which(holds_root==1) ]]

    if(!is_binary_tree( pruned )) stop("trim_tree_adaptive_top_down, the resulted tree is not a binary tree")
    pruned
  }
  ## Trim every branch tree and report the result.
  ##
  ## branches: list of igraph branch trees (may be empty)
  ## p0: p-value threshold handed to the trimming routine; when to_correct is not
  ##     FALSE it is divided by (total vertex count - total number of leaves
  ##     reported by get_leaves()) over all branches
  ## to_correct: whether to apply that division to p0
  ## width_thresh, boundary_signal_thresh, peak_thresh: forwarded to
  ##     trim_tree_adaptive() (used when top_down is FALSE)
  ## width_thresh_CD: forwarded to get_adjusted_nested_TADs() (used when
  ##     CD_border_adj is TRUE)
  ## return_which: 'trimmed_branches' returns the list of trimmed trees; any other
  ##     value returns the TAD table
  ## top_down: TRUE trims with trim_tree_adaptive_top_down_v2() instead of
  ##     trim_tree_adaptive()
  ## all_levels: TRUE returns one row per vertex of every trimmed branch
  ## CD_border_adj: TRUE returns the result of get_adjusted_nested_TADs()
  ## mean_diff_thresh: forwarded to trim_tree_adaptive_top_down_v2(); only
  ##     evaluated when top_down is TRUE
  ##
  ## value: data.frame(start_pos, end_pos), or the list of trimmed branches.
  ##        CD_border_adj takes precedence over all_levels, which takes
  ##        precedence over return_which. An empty `branches` list yields a
  ##        zero-row table (or an empty list) instead of an error.
  prunning = function(branches, p0, to_correct=FALSE, width_thresh=-Inf, width_thresh_CD=5, boundary_signal_thresh=-Inf, return_which='TADs', top_down=FALSE, all_levels=FALSE, CD_border_adj=FALSE, peak_thresh=NULL, mean_diff_thresh)
  {
    empty_tads = data.frame(start_pos=numeric(), end_pos=numeric())

    if(to_correct==FALSE)
    {
      p_thresh = p0
    } else {
      ## vapply keeps these numeric even for an empty list (sapply would return list())
      n_nodes = vapply( branches, function(v) as.numeric(igraph::vcount(v)), numeric(1) )
      n_leaves = vapply( branches, function(v) as.numeric(length(get_leaves(v))), numeric(1) )
      size2correct = sum(n_nodes) - sum(n_leaves)
      p_thresh = p0/size2correct
    }

    if(top_down==FALSE)
    {
      trimmed_branches = lapply( branches, trim_tree_adaptive, max_imp_p=p_thresh, max_nimp_p=Inf, width_thresh=width_thresh, boundary_signal_thresh=boundary_signal_thresh, peak_thresh=peak_thresh )
    }
    if(top_down==TRUE)
    {
      trimmed_branches = lapply( branches, function(branch) trim_tree_adaptive_top_down_v2( wilcox_p_thresh=p_thresh, mean_diff_thresh=mean_diff_thresh, tree=branch ) )
    }

    if(CD_border_adj==TRUE)
    {
      all_tads = get_adjusted_nested_TADs( trimmed_branches, width_thresh_CD, all_levels )
      return( all_tads )
    }

    ## get all nested TADs in trimmed_branches
    if(all_levels==TRUE)
    {
      widths = vapply( trimmed_branches, function(v) as.numeric(igraph::V(v)[1]$width), numeric(1) )
      ## offset of branch i = summed first-vertex widths of branches 1..(i-1);
      ## unname() so the scalar offset cannot leak a name into the positions
      shifts = unname(cumsum(c(0, widths)))
      per_branch = lapply( seq_along(trimmed_branches), function(i) get_all_tads_in_a_trimmed_branch(trimmed_branches[[i]], pos_shift=shifts[i]) )
      ## left fold keeps the same rbind order (and row names) as appending branch by branch
      return( Reduce(rbind, per_branch, empty_tads) )
    }

    if( return_which=='trimmed_branches' ) return( trimmed_branches )

    ## leaves of the trimmed trees are the TADs: consecutive intervals starting at 1
    tad_sizes_ind = lapply( trimmed_branches, function(v) get_leaves(v, 'igraph')$width )
    tad_sizes = unlist(tad_sizes_ind)
    if(length(tad_sizes)==0) return( empty_tads )
    end_pos = cumsum(tad_sizes)
    start_pos = c(1, 1 + end_pos[-length(end_pos)])
    tads = data.frame(start_pos=start_pos, end_pos=end_pos)
    return( tads )
  }
  ## This function combines prunning with branches of only one node
  ##
  ## branches: list mixing igraph branch trees ("normal") with non-igraph
  ##     entries ("unnormal"; per the original note, represented as
  ##     bin_start:bin_end)
  ## ...: forwarded to prunning() for the igraph branches
  ## value: data.frame(start_pos, end_pos) of consecutive intervals starting at 1,
  ##     in the original branch order; zero rows for an empty `branches` list
  ##
  ## NOTE(review): the values of the non-igraph entries are concatenated as-is
  ## into the vector of TAD sizes -- confirm that they hold widths.
  prunning_hybrid <- function(branches, ...)
  {
    ## positional names allow restoring the original order after the split below
    names(branches) = as.character(seq_along(branches))
    is_tree = vapply( branches, function(v) inherits(v, 'igraph'), logical(1) )
    normal_branches = branches[is_tree]
    unnormal_branches = branches[!is_tree]

    ## nothing to trim when there is no igraph branch at all
    trimmed_branches = if(length(normal_branches) > 0) prunning(normal_branches, return_which='trimmed_branches', ...) else list()

    normal_tad_sizes_ind = lapply( trimmed_branches, function(v) get_leaves(v, 'igraph')$width )
    unormal_tad_sizes_ind = unnormal_branches
    tad_sizes_ind = c(normal_tad_sizes_ind, unormal_tad_sizes_ind)
    tad_sizes_ind = tad_sizes_ind[names(branches)]
    tad_sizes = unlist(tad_sizes_ind)
    if(length(tad_sizes)==0) return( data.frame(start_pos=numeric(), end_pos=numeric()) )

    end_pos = cumsum(tad_sizes)
    start_pos = c(1, 1 + end_pos[-length(end_pos)])
    tads = data.frame(start_pos=start_pos, end_pos=end_pos)
    return( tads )
  }
  ## Tabulate every vertex of a trimmed branch as an interval.
  ##
  ## trimmed_branch: igraph object whose vertices carry `left` and `right` attributes
  ## pos_shift: offset added to both attributes
  ## value: data.frame with columns start_pos and end_pos, one row per vertex,
  ##        ordered by start_pos and then by end_pos
  get_all_tads_in_a_trimmed_branch <- function(trimmed_branch, pos_shift)
  {
    nodes <- igraph::V(trimmed_branch)
    starts <- nodes$left + pos_shift
    ends <- nodes$right + pos_shift
    tads <- data.frame( start_pos=starts, end_pos=ends )
    tads[order(tads[,1], tads[,2]), ]
  }
  ## Trim every branch with trim_tree_adaptive() and return the resulting TADs.
  ##
  ## branches: list of igraph branch trees (may be empty)
  ## p0: p-value threshold passed as max_imp_p; when NULL, 0.05 divided by
  ##     (total vertex count - total number of leaves reported by get_leaves())
  ##     over all branches is used
  ## width_thresh: forwarded to trim_tree_adaptive()
  ## value: data.frame(start_pos, end_pos) of consecutive intervals starting at 1
  ##     whose lengths are the widths of the leaves of the trimmed branches;
  ##     zero rows for an empty `branches` list
  prunning_bottom_up <- function(branches, p0=NULL, width_thresh)
  {
    if(is.null(p0))
    {
      ## vapply keeps these numeric even for an empty list (sapply would return list())
      n_nodes = vapply( branches, function(v) as.numeric(igraph::vcount(v)), numeric(1) )
      n_leaves = vapply( branches, function(v) as.numeric(length(get_leaves(v))), numeric(1) )
      size2correct = sum(n_nodes) - sum(n_leaves)
      p_thresh = 0.05/size2correct
    } else {
      p_thresh = p0
    }

    trimmed_branches = lapply( branches, function(branch) trim_tree_adaptive( branch, max_imp_p=p_thresh, max_nimp_p=Inf, width_thresh=width_thresh ) )

    tad_sizes_ind = lapply( trimmed_branches, function(v) get_leaves(v, 'igraph')$width )
    tad_sizes = unlist(tad_sizes_ind)
    if(length(tad_sizes)==0) return( data.frame(start_pos=numeric(), end_pos=numeric()) )

    end_pos = cumsum(tad_sizes)
    start_pos = c(1, 1 + end_pos[-length(end_pos)])
    tads = data.frame(start_pos=start_pos, end_pos=end_pos)
    return( tads )
  }
  ## Trim one tree at every distinct imp_p value it contains.
  ##
  ## tree: igraph object whose vertices carry an `imp_p` attribute
  ## which_p: which p-value to scan; only 'imp_p' is implemented
  ## value: list(ps, trimmed_branch_bottom_up). `ps` holds the distinct imp_p
  ##     values in decreasing order and `trimmed_branch_bottom_up` the tree
  ##     returned by trim_tree_adaptive() with max_imp_p set to each of them;
  ##     only the first entry for each distinct vertex count is kept.
  trim_tree_adaptive_bottom_up <- function( tree, which_p='imp_p' )
  {
    ## fail with a clear message instead of "object 'ps' not found"
    if(which_p!='imp_p') stop("trim_tree_adaptive_bottom_up, only which_p='imp_p' is supported")

    ps = sort(unique(igraph::V(tree)$imp_p), decreasing=TRUE)

    ## every threshold is applied to the original tree, not to the previous result
    trimmed_branch_bottom_up = lapply( ps, function(p) trim_tree_adaptive( tree, L_diff_thresh=-Inf, max_imp_p=p, max_nimp_p=Inf, width_thresh=-Inf ) )

    ## plain local variable: a namespaced name (igraph::...) cannot be assigned to
    vcounts = vapply( trimmed_branch_bottom_up, function(v) as.numeric(igraph::vcount(v)), numeric(1) )
    keep = !duplicated(vcounts)

    res = list(ps=ps[keep], trimmed_branch_bottom_up=trimmed_branch_bottom_up[keep])
    return( res )
  }
  ## get adjusted nested TADs
  ##
  ## Lists the intervals of every trimmed branch (each branch is called a CD in
  ## the comments below), shifted by the summed first-vertex widths of the
  ## preceding branches. When the two left-most (or right-most) distinct borders
  ## of a branch are at most width_thresh_CD apart, the outer border is moved
  ## onto the inner one and the facing border of the neighbouring branch is
  ## moved so that the two branches stay adjacent.
  ##
  ## trimmed_branches: list of trimmed igraph branch trees (may be empty)
  ## width_thresh_CD: largest border distance that triggers an adjustment
  ## all_levels: TRUE keeps every interval; FALSE keeps one row per branch, from
  ##     the start of its first row to the end of its last row
  ## value: data.frame with columns start_pos and end_pos (zero rows for an
  ##     empty `trimmed_branches` list)
  get_adjusted_nested_TADs <- function( trimmed_branches, width_thresh_CD, all_levels )
  {
    widths = c(0, vapply(trimmed_branches, function(v) as.numeric(igraph::V(v)[1]$width), numeric(1)))
    all_tads_i_list = lapply( seq_along(trimmed_branches), function(i) get_all_tads_in_a_trimmed_branch(trimmed_branches[[i]], pos_shift=sum(widths[1:i])))
    for(i in seq_along(trimmed_branches))
    {
      all_tads_i = all_tads_i_list[[i]]
      if( nrow(all_tads_i) <= 1 ) next

      ## move the left-most border a little bit right if needed
      left_borders = unique(all_tads_i[,1])
      ## a single distinct left border leaves nothing to merge (left_borders[2] would be NA)
      if( length(left_borders) >= 2 && (left_borders[2] - left_borders[1]) <= width_thresh_CD )
      {
        all_tads_i[ all_tads_i==left_borders[1] ] = left_borders[2]
        all_tads_i = all_tads_i[ all_tads_i[,2] > all_tads_i[,1], ] ## remove "negative" TADs
        all_tads_i = unique(all_tads_i[order(all_tads_i[,1], all_tads_i[,2]), ]) ## reorder the TADs
        all_tads_i_list[[i]] = all_tads_i

        ## need to modify the right border of nested TADs in previous CD if the left border of this CD is modified
        if(i > 1)
        {
          ## replace the max value of [i-1], i.e., the right most border, as the min of [i]-1, i.e., the left most border of [i]
          all_tads_i_list[[i-1]][ all_tads_i_list[[i-1]]==max(all_tads_i_list[[i-1]]) ] = min(all_tads_i_list[[i]]) - 1
        }
      }

      if( nrow(all_tads_i) <= 1 ) next

      ## move the right-most border a little bit left if needed
      right_borders = unique(rev(all_tads_i[,2]))
      ## a single distinct right border leaves nothing to merge (right_borders[2] would be NA)
      if( length(right_borders) >= 2 && (right_borders[1] - right_borders[2]) <= width_thresh_CD )
      {
        all_tads_i[ all_tads_i==right_borders[1] ] = right_borders[2]
        all_tads_i = all_tads_i[ all_tads_i[,2] > all_tads_i[,1], ] ## remove "negative" TADs
        all_tads_i = unique(all_tads_i[order(all_tads_i[,1], all_tads_i[,2]), ]) ## reorder the TADs
        all_tads_i_list[[i]] = all_tads_i

        ## need to modify the left border of nested TADs in next CD if the right border of this CD is modified
        if(i < length(trimmed_branches))
        {
          ## replace the min value of [i+1], i.e., the left most border, as the max of [i]+1, i.e., the right most border of [i]
          all_tads_i_list[[i+1]][ all_tads_i_list[[i+1]]==min(all_tads_i_list[[i+1]]) ] = max(all_tads_i_list[[i]]) + 1
        }
      }
    }

    if(!all_levels) all_tads_i_list = lapply( all_tads_i_list, function(v) data.frame(start_pos=head(v[,1],1), end_pos=tail(v[,2],1)) )

    ## do.call(rbind, list()) is NULL, so handle the empty case explicitly
    if(length(all_tads_i_list)==0) return( data.frame(start_pos=numeric(), end_pos=numeric()) )

    all_tads = do.call(rbind, all_tads_i_list)
    colnames(all_tads) = c('start_pos', 'end_pos')
    return( all_tads )
  }