zigzag_loglik_ancestors_v4.5.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. // -*- mode: C++; c-indent-level: 4; c-basic-offset: 4; indent-tabs-mode: nil; -*-
  2. #include "RcppArmadillo.h"
  3. using namespace Rcpp;
  4. using namespace std;
  5. // #define M_PI 3.141592653589793238462643383280 /* pi */
  6. // [[Rcpp::export]]
  7. double loglik_lnorm_cpp( double sum_ln1, double sum_ln2, double p, double q ) // several times faster than fitdistr
  8. {
  9. if( sum_ln2 < 0 ) cout << "sum_ln2 not valid in loglik_lnorm_cpp\n";
  10. if( p < 0 ) cout << "p not valid in loglik_lnorm_cpp\n";
  11. if( q < 0 ) cout << "q not valid in loglik_lnorm_cpp\n";
  12. if( q <=1 ) return 0; // p = sum(x==0): the number of zero-values
  13. double alpha = 1.0*p/(p+q);
  14. double mu = sum_ln1/q;
  15. double sigma2 = sum_ln2/q + pow(mu, 2) - 2*mu*sum_ln1/q; // %: element-wise product
  16. if( abs(sigma2) <=1E-10 )
  17. {
  18. return 0;
  19. }
  20. double loglik = -sum_ln1 - 0.5*q*(log(2*M_PI) + log(sigma2)) - (sum_ln2 + q*pow(mu, 2) - 2*mu*sum_ln1)/2/sigma2;
  21. if( p==0 ) {return loglik;} else {return p*log(alpha) + q*log(1-alpha) + loglik;}
  22. }
  23. // [[Rcpp::export]]
  24. double loglik_lnorm_cpp_vec( arma::vec vec_values ) //log-likelihood of vec_values
  25. {
  26. int p, q;
  27. int n = vec_values.n_elem;
  28. if( n < 2 ) return 0; // added on 19-07-2018
  29. double sum_ln1, sum_ln2;
  30. arma::vec positive_vec = vec_values.elem(find(vec_values > 0));
  31. q = positive_vec.n_elem; // the number of positive values
  32. p = n - q; // the number of zeros
  33. if( q <= 1 ) return 0;
  34. sum_ln1 = sum(log(positive_vec));
  35. sum_ln2 = sum(pow(log(positive_vec), 2));
  36. return loglik_lnorm_cpp( sum_ln1, sum_ln2, p, q );
  37. }
  38. // [[Rcpp::export]]
  39. arma::mat get_A_len(arma::mat A) // get matrix A_len: A_len := A*(A>0). This is used for computing the number of positive elements in a rectangle region
  40. {
  41. int n_row = A.n_rows;
  42. arma::mat A_len=arma::zeros<arma::mat>(n_row, n_row); // for test
  43. arma::uvec ids = find(A > 0);
  44. arma::vec new_values = arma::ones<arma::vec>(ids.n_elem);
  45. A_len.elem(ids) = new_values;
  46. return A_len;
  47. }
  48. // [[Rcpp::export]]
  49. arma::mat get_A_ln1(arma::mat A) // log(A_ij)
  50. {
  51. int n_row = A.n_rows;
  52. arma::mat A_ln1=arma::zeros<arma::mat>(n_row, n_row); // for test
  53. arma::uvec ids = find(A > 0);
  54. arma::vec new_values = log(A.elem(ids));
  55. A_ln1.elem(ids) = new_values;
  56. return A_ln1;
  57. }
  58. // [[Rcpp::export]]
  59. arma::mat get_A_ln2(arma::mat A) // log(A_ij)^2
  60. {
  61. int n_row = A.n_rows;
  62. arma::mat A_ln2=arma::zeros<arma::mat>(n_row, n_row); // for test
  63. arma::uvec ids = find(A > 0);
  64. arma::vec new_values = pow(log(A.elem(ids)), 2);
  65. A_ln2.elem(ids) = new_values;
  66. return A_ln2;
  67. }
  68. // compute the loglik matrix
  69. // [[Rcpp::export]]
  70. arma::mat loglik_lnorm_cpp_mat( arma::mat sum_ln1, arma::mat sum_ln2, arma::mat ps, arma::mat qs ) // several times faster than fitdistr
  71. {
  72. int n_row = sum_ln1.n_rows;
  73. int n_col = sum_ln1.n_cols;
  74. arma::mat loglik(n_row, n_col);
  75. for(int i=0; i<n_row; i++)
  76. for(int j=0; j<n_col; j++)
  77. loglik(i,j) = loglik_lnorm_cpp( sum_ln1(i,j), sum_ln2(i,j), ps(i,j), qs(i,j) );
  78. return loglik;
  79. }
  80. /////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  81. // compute the blockwise sum
  82. // [[Rcpp::export]] // use the same version name v4 instead of v4_5
  83. List zigzag_loglik_ancestors_v4_5( arma::mat A, int k, int min_n_bins=2 ){
  84. int n_row = A.n_rows;
  85. // int min_n_bins = 2;
  86. int p, q, max_mid_index, max_mid;
  87. arma::mat ps, loglik_tmp, loglik;
  88. arma::vec n_cells;
  89. arma::mat A_len = get_A_len(A); // A_len = A*(A > 0)
  90. arma::mat A_ln1 = get_A_ln1(A); // log(A_ij)
  91. arma::mat A_ln2 = get_A_ln2(A); // log(A_ij)^2
  92. StringMatrix ancestors(n_row, n_row);
  93. arma::mat L=arma::zeros<arma::mat>(n_row, n_row); // for test
  94. Rcpp::List res;
  95. for( int k= min_n_bins-1; k<=(2*min_n_bins - 2); k++ ) // other values of L are 0
  96. {
  97. for( int v=1; v<= (n_row - k); v++ )
  98. {
  99. arma::mat tmp_mat=A.submat(arma::span(v-1, v+k-1), arma::span(v-1, v+k-1)); // span(0,1) := 1:2 in R
  100. arma::vec upper_tri_vec = tmp_mat.elem(find(trimatu(tmp_mat)));
  101. L(v-1, v-1+k) = loglik_lnorm_cpp_vec( upper_tri_vec );
  102. }
  103. } // Checked to be the same in R
  104. // cout << "Finished initialize L\n";
  105. // initialize the rad_mat as the first off-diagonal values
  106. arma::mat rad_mat_current_ln1(n_row-1, 1);
  107. arma::mat rad_mat_current_ln2(n_row-1, 1);
  108. arma::mat rad_mat_current_len(n_row-1, 1);
  109. for(int i=0; i<(n_row-1); i++) // The first off-diagonal values
  110. {
  111. rad_mat_current_ln1(i, 0) = A_ln1(i, i+1);
  112. rad_mat_current_ln2(i, 0) = A_ln2(i, i+1);
  113. rad_mat_current_len(i, 0) = A_len(i, i+1);
  114. }
  115. // initialized to be two vertical cells (2 rows)
  116. arma::mat vertical_columns_next_ln1(2, n_row-2);
  117. arma::mat vertical_columns_next_ln2(2, n_row-2);
  118. arma::mat vertical_columns_next_len(2, n_row-2);
  119. for(int i=0; i<(n_row-2); i++)
  120. {
  121. vertical_columns_next_ln1.col(i) = A_ln1( arma::span(i, i+1), i+2 );
  122. vertical_columns_next_ln2.col(i) = A_ln2( arma::span(i, i+1), i+2 );
  123. vertical_columns_next_len.col(i) = A_len( arma::span(i, i+1), i+2 );
  124. }
  125. for(int i=1; i<2; i++) // cumsum of the two vertical cells (i=1:1)
  126. {
  127. vertical_columns_next_ln1.row(i) = vertical_columns_next_ln1.row(i) + vertical_columns_next_ln1.row(i-1); //cumsum
  128. vertical_columns_next_ln2.row(i) = vertical_columns_next_ln2.row(i) + vertical_columns_next_ln2.row(i-1); //cumsum
  129. vertical_columns_next_len.row(i) = vertical_columns_next_len.row(i) + vertical_columns_next_len.row(i-1); //cumsum
  130. }
  131. arma::mat rad_mat_next_ln1 = rad_mat_current_ln1;
  132. arma::mat rad_mat_next_ln2 = rad_mat_current_ln2;
  133. arma::mat rad_mat_next_len = rad_mat_current_len;
  134. arma::mat vertical_columns_current_ln1 = vertical_columns_next_ln1; // this line just create the vertical_columns_current_ln
  135. arma::mat vertical_columns_current_ln2 = vertical_columns_next_ln2; // this line just create the vertical_columns_current_ln
  136. arma::mat vertical_columns_current_len = vertical_columns_next_len; // this line just create the vertical_columns_current_ln
  137. // Rcout << L << "\n";
  138. // cout << "Begin iteration:\n";
  139. // time complexity of this part: n^3
  140. // for(int shift=3; shift<=n_row; shift++)
  141. // each row of rad_mat represent one off-diagonal point
  142. for(int shift=3; shift<=k; shift++)
  143. {
  144. rad_mat_current_ln1 = rad_mat_next_ln1;
  145. rad_mat_current_ln2 = rad_mat_next_ln2;
  146. rad_mat_current_len = rad_mat_next_len;
  147. rad_mat_next_ln1 = arma::mat( n_row-shift+1, shift-1);
  148. rad_mat_next_ln2 = arma::mat( n_row-shift+1, shift-1);
  149. rad_mat_next_len = arma::mat( n_row-shift+1, shift-1);
  150. n_cells = arma::zeros<arma::vec>(shift-1); // size of each rectangle
  151. for( int i=0; i< (shift-1); i++ ) n_cells(i) = (i+1)*(shift-i-1);
  152. arma::mat rad_mat_next_len_all = (arma::ones<arma::mat>(n_row-shift+1, shift-1))*(arma::diagmat(n_cells)); // In R: rep(vec, n_row times) // Schur product: element-wise multiplication of two objects
  153. for(int i=1; i<=(n_row-shift+1); i++)
  154. {
  155. // next = current + vertical_columns_next_ln values
  156. rad_mat_next_ln1.submat(i-1, 0, i-1, shift-2-1) = rad_mat_current_ln1( i-1, arma::span(0, shift-2-1) ) + vertical_columns_next_ln1( arma::span(0, shift-2-1), i-1 ).t();
  157. rad_mat_next_ln2.submat(i-1, 0, i-1, shift-2-1) = rad_mat_current_ln2( i-1, arma::span(0, shift-2-1) ) + vertical_columns_next_ln2( arma::span(0, shift-2-1), i-1 ).t();
  158. rad_mat_next_len.submat(i-1, 0, i-1, shift-2-1) = rad_mat_current_len( i-1, arma::span(0, shift-2-1) ) + vertical_columns_next_len( arma::span(0, shift-2-1), i-1 ).t();
  159. rad_mat_next_ln1(i-1, shift-1-1) = vertical_columns_next_ln1(shift-1-1, i-1); // the last new element
  160. rad_mat_next_ln2(i-1, shift-1-1) = vertical_columns_next_ln2(shift-1-1, i-1); // the last new element
  161. rad_mat_next_len(i-1, shift-1-1) = vertical_columns_next_len(shift-1-1, i-1); // the last new element
  162. }
  163. ////////////////////////////////////// compute the vertical_columns_next values
  164. if(shift < n_row) //stop when shift=n
  165. {
  166. vertical_columns_current_ln1 = vertical_columns_next_ln1;
  167. vertical_columns_current_ln2 = vertical_columns_next_ln2;
  168. vertical_columns_current_len = vertical_columns_next_len;
  169. arma::mat first_row_ln1(1, n_row-shift);
  170. arma::mat first_row_ln2(1, n_row-shift);
  171. arma::mat first_row_len(1, n_row-shift);
  172. for(int i=0; i<(n_row-shift); i++)
  173. {
  174. first_row_ln1(0, i) = A_ln1(i, i+shift); // off-diagonal values to be appended to vertical_columns_next_ln
  175. first_row_ln2(0, i) = A_ln2(i, i+shift); // off-diagonal values to be appended to vertical_columns_next_ln
  176. first_row_len(0, i) = A_len(i, i+shift); // off-diagonal values to be appended to vertical_columns_next_ln
  177. }
  178. vertical_columns_next_ln1 = vertical_columns_current_ln1.submat(0, 1, shift-2, n_row-shift); // drop the first column
  179. vertical_columns_next_ln2 = vertical_columns_current_ln2.submat(0, 1, shift-2, n_row-shift); // drop the first column
  180. vertical_columns_next_len = vertical_columns_current_len.submat(0, 1, shift-2, n_row-shift); // drop the first column
  181. vertical_columns_next_ln1 = arma::join_cols(first_row_ln1, vertical_columns_next_ln1);
  182. vertical_columns_next_ln2 = arma::join_cols(first_row_ln2, vertical_columns_next_ln2);
  183. vertical_columns_next_len = arma::join_cols(first_row_len, vertical_columns_next_len);
  184. for(int i=1; i<shift; i++)
  185. {
  186. vertical_columns_next_ln1.row(i) = vertical_columns_next_ln1.row(i) + first_row_ln1; // cumsum
  187. vertical_columns_next_ln2.row(i) = vertical_columns_next_ln2.row(i) + first_row_ln2; // cumsum
  188. vertical_columns_next_len.row(i) = vertical_columns_next_len.row(i) + first_row_len; // cumsum
  189. }
  190. }
  191. ////////////////////////////////////// compute the L and ancestors
  192. // if(shift >= 4)
  193. if(shift >= 2*min_n_bins) //
  194. {
  195. ps = rad_mat_next_len_all - rad_mat_next_len; // number of positive values
  196. loglik_tmp = loglik_lnorm_cpp_mat( rad_mat_next_ln1, rad_mat_next_ln2, ps, rad_mat_next_len );
  197. // loglik = loglik_tmp.submat(0, 1, n_row-shift, shift-2-1); // remove first and last col because of the min_n_bins. SHOULD BE MODIFIED. submat: X.submat( first_row, first_col, last_row, last_col ), http://arma.sourceforge.net/docs.html#submat
  198. loglik = loglik_tmp.submat(0, min_n_bins-1, n_row-shift, shift-2-(min_n_bins-1)); // 2018-11-14, remove first and last min_n_bins-1 cols because of the min_n_bins. SHOULD BE MODIFIED. submat: X.submat( first_row, first_col, last_row, last_col ), http://arma.sourceforge.net/docs.html#submat
  199. arma::mat cases(1, shift-2*min_n_bins+1); // shift=5: 1:2, i.e., two cases
  200. // arma::mat loglik(n_row-shift, shift-min_n_bins);
  201. for( int row=1; row<=(n_row-shift+1); row++ )
  202. {
  203. p = row; //7
  204. q = row + shift; // 7 + 4 = 11
  205. cases = loglik( p-1, arma::span(0, shift-2*min_n_bins) ) + L(p-1, arma::span(p-1-1 + min_n_bins, q - min_n_bins-1-1) ) + (L(arma::span(p-1 + min_n_bins, q - min_n_bins-1), q-1-1)).t();
  206. L(p-1, q-1-1) = cases.max();
  207. max_mid_index = cases.index_max() + 1; //The c++ offset 1
  208. max_mid = (p-1 + min_n_bins) + max_mid_index -1; // should minus one
  209. // ancestor = paste(i, max_mid, max_mid+1, j, sep='-')
  210. ancestors(p-1, q-1-1) = to_string(p) + "-" + to_string(max_mid) + "-" + to_string(max_mid+1) + "-" + to_string(q-1);
  211. // cout << "cases:" << L(p-1, q-1-1) << "\n";
  212. }
  213. }
  214. }
  215. res["ancestors"] = ancestors;
  216. res["L"] = L;
  217. return res;
  218. }
  219. // compute the ancestors
  220. // [[Rcpp::export]]
  221. List compute_L( arma::mat A, arma::mat L, int k ) // A seems not needed here, and can be removed, Y.L, 2018-11-14 (Indeed this part is not used)
  222. {
  223. int n_row = A.n_rows;
  224. int min_n_bins = 2;
  225. StringMatrix ancestors(n_row, n_row);
  226. // should rewrite this part
  227. // for( int i= min_n_bins-1; i<=(2*min_n_bins - 2); i++ )
  228. // for( int j=i; j<n_row; j++ ) L(j-i, j) = 1;
  229. for( int shift=(2*min_n_bins - 1); shift<=(n_row-1); shift++ )
  230. {
  231. int i, j, max_mid_index, max_mid;
  232. arma::mat cases(1, shift-2*min_n_bins+1);
  233. arma::mat prob_vec(n_row-shift, shift-min_n_bins);
  234. for( int row=1; row<=(n_row-shift); row++ )
  235. {
  236. i = row;
  237. j = row + shift;
  238. cases = prob_vec( i-1, arma::span(0, j-2*min_n_bins-i+1) ) + L(i-1, arma::span(i-1-1 + min_n_bins, j - min_n_bins-1) ) + (L(arma::span(i + min_n_bins-1, j - min_n_bins), j-1)).t();
  239. L(i-1, j-1) = cases.max();
  240. max_mid_index = cases.index_max() + 1; //The c++ offset 1
  241. max_mid = (i-1 + min_n_bins) + max_mid_index -1; // should minus one
  242. // ancestor = paste(i, max_mid, max_mid+1, j, sep='-')
  243. ancestors(i-1, j-1) = to_string(i) + "-" + to_string(max_mid) + "-" + to_string(max_mid+1) + "-" + to_string(j); // This 'to_string' may be optimized
  244. }
  245. }
  246. return List::create(Named("L") = L, Named("ancestors") = ancestors);
  247. }