## bisecting_kmeans.R
  ## k-means with replicatable seeds.
  ##
  ## Thin wrapper around kmeans() that seeds the random number generator
  ## immediately before clustering, so repeated calls on the same input give
  ## the same result.
  ##
  ## Arguments:
  ##   iter.max  maximum number of iterations handed to kmeans() (default 1E3)
  ##   nstart    number of random starts handed to kmeans() (default 50)
  ##   ...       further arguments forwarded unchanged to kmeans()
  ##             (callers pass `x` and `centers` this way)
  ##   seed      seed given to set.seed() before clustering. Defaults to 1,
  ##             the value that used to be hard-coded, so existing calls are
  ##             unaffected. Must be passed by name. NULL skips seeding.
  ##
  ## Returns the object returned by kmeans().
  ##
  ## NOTE: set.seed() changes the session-wide RNG state; that side effect is
  ## kept on purpose so callers see the same behaviour as before.
  my_kmeans <- function(iter.max = 1E3, nstart = 50, ..., seed = 1)
  {
    if (!is.null(seed)) {
      set.seed(seed)
    }
    stats::kmeans(iter.max = iter.max, nstart = nstart, ...)
  }
  ## Adjust the height of each split so that the heights can be used to build
  ## a valid hclust object with balanced compartments A.1 A.2 B.1 B.2.
  ##
  ## Argument:
  ##   l_r_h  list of splits; each element is a list with
  ##            l  indices on the left side of the split
  ##            r  indices on the right side of the split
  ##            h  height of the split
  ##          The first element must be the top-level (root) split.
  ##
  ## Returns a numeric vector of adjusted heights, one per split, in the same
  ## order as `l_r_h`.
  ##
  ## Stops with an error if two adjusted heights are still identical.
  adjust_hs <- function(l_r_h)
  {
    ## Key identifying a node by its (sorted) member indices, e.g. "1_2_3".
    member_key <- function(ids) paste0(collapse='_', sort(ids))

    hs <- vapply(l_r_h, function(v) as.numeric(v$h), numeric(1))
    sizes <- vapply(l_r_h, function(v) length(v$l) + length(v$r), integer(1))

    ################ This part deals with duplicated heights
    ## Clusters with more nodes get the bigger height when heights are equal.
    hs <- hs + sizes * 1E-7

    ################ This part makes the two top-level branches (the left and
    ## right child of the root split) the 2nd and 3rd highest merges, so that
    ## the A.1, A.2, B.1, B.2 compartments are balanced.
    if (length(l_r_h) > 0L) {
      all_names <- vapply(l_r_h, function(v) member_key(c(v$l, v$r)), character(1))

      ## Locate both top-level branches by their members. A branch that is a
      ## single leaf has no split of its own and yields NA. (The left branch
      ## used to be hard-coded as element 2, which is only true when the left
      ## child is not a leaf; an unmatched right branch used to crash.)
      l_b <- match(member_key(l_r_h[[1]]$l), all_names) ## left sub-branch
      r_b <- match(member_key(l_r_h[[1]]$r), all_names) ## right sub-branch

      ## Balancing only makes sense when both children of the root are
      ## internal nodes; this also guarantees at least three heights exist.
      if (!is.na(l_b) && !is.na(r_b)) {
        l_h <- hs[l_b]
        r_h <- hs[r_b]
        max_h <- max(l_h, r_h) ## the maximum height of the two branches
        ## mean of the 2nd and 3rd largest heights
        hs_new <- mean(sort(hs, decreasing=TRUE)[2:3])
        ## The taller branch keeps its height, the other one is moved to hs_new.
        hs[l_b] <- if (l_h > r_h) max_h else hs_new
        hs[r_b] <- if (r_h > l_h) max_h else hs_new
      }
    }

    if (anyDuplicated(hs) > 0L) stop('ERROR: DUPLICATED HEIGHTS exist in bisecting_kmeans')
    hs
  }
  ## Top-down hierarchical clustering by repeated 2-means bisection.
  ##
  ## The rows of `data` are split recursively into two groups with
  ## my_kmeans(centers=2) until every group holds a single row. Each split is
  ## given a height equal to the diameter (largest pairwise distance from
  ## stats::dist()) of the rows it separates. The heights are passed through
  ## adjust_hs() and the splits are then assembled into an hclust object.
  ##
  ## Argument:
  ##   data  matrix or data frame with one observation per row (at least two
  ##         rows)
  ##
  ## Returns an object of class "hclust" whose labels are the row numbers
  ## 1..nrow(data), with method "complete" and dist.method "euclidean".
  ##
  ## The splits are collected in a local list and returned by the recursion;
  ## nothing is written to the global environment.
  bisecting_kmeans <- function(data)
  {
    n_obs <- nrow(data)
    if (is.null(n_obs) || n_obs < 2L) {
      stop("bisecting_kmeans: 'data' must be a matrix or data frame with at least two rows", call. = FALSE)
    }

    dist_mat <- as.matrix(stats::dist(data))
    leaf_ids <- seq_len(n_obs)

    ## Key identifying a node by its member indices, e.g. "1_2_3".
    member_key <- function(ids) paste0(ids, collapse='_')

    ## Height of a split: diameter of the union of both sides.
    get_h <- function(l_indices, r_indices)
    {
      combined_indices <- c(l_indices, r_indices)
      max(dist_mat[combined_indices, combined_indices]) ## diameter
    }

    ## Recursively bisect `indices`. Returns the list of splits in pre-order:
    ## this split first, then all splits of the left side, then all splits of
    ## the right side. Each split is list(l=, r=, h=).
    get_sub_tree <- function(indices)
    {
      n_nodes <- length(indices)
      if (n_nodes < 2L) return(list()) ## a single node cannot be split further

      ## Two nodes are separated directly; more than two go through k-means.
      cluster <- if (n_nodes == 2L) {
        c(1, 2)
      } else {
        my_kmeans(x=data[indices, , drop = FALSE], centers=2)$cluster
      }
      l_indices <- indices[cluster == 1]
      r_indices <- indices[cluster == 2]
      this_split <- list(l=l_indices, r=r_indices, h=get_h(l_indices, r_indices))

      c(list(this_split), get_sub_tree(l_indices), get_sub_tree(r_indices))
    }

    l_r_h <- get_sub_tree(leaf_ids)
    hs <- adjust_hs(l_r_h)

    ## hclust numbering: an internal node is identified by the row of the
    ## merge that created it (its height rank), a leaf by minus its index.
    node_keys <- vapply(l_r_h, function(v) member_key(sort(c(v$l, v$r))), character(1))
    pos_names <- stats::setNames(rank(hs), node_keys)
    neg_names <- stats::setNames(-leaf_ids, leaf_ids)
    all_names <- c(pos_names, neg_names)

    ## l and r are subsets of an increasing index vector, so they are already
    ## sorted and their keys match the ones built above.
    l_name <- unname(all_names[vapply(l_r_h, function(v) member_key(v$l), character(1))])
    r_name <- unname(all_names[vapply(l_r_h, function(v) member_key(v$r), character(1))])

    merge_height <- data.frame(l=l_name, r=r_name, h=hs)
    merge_height <- merge_height[order(merge_height$h), ]
    rownames(merge_height) <- NULL

    ## Build a small genuine hclust object and use it as a template whose
    ## merge / height / labels are overwritten with the bisection result.
    data_tmp <- cbind(c(0,0,1,1), c(0,1,1,0))
    hc <- stats::hclust(stats::dist(data_tmp), "complete")
    hc$merge <- as.matrix(unname(merge_height[,1:2]))
    hc$height <- merge_height$h
    hc$labels <- leaf_ids

    ## Round trip through a dendrogram, reordered with weights 1..n, to
    ## obtain the final hclust object.
    den <- stats::as.dendrogram(hc)
    hc_r <- stats::as.hclust(stats::reorder(den, leaf_ids))
    hc_r$method <- "complete"
    hc_r$dist.method <- "euclidean"
    hc_r
  }