planc
Parallel Low-rank Approximation with Non-negativity Constraints
distauntf.hpp
/* Copyright 2016 Ramakrishnan Kannan */
#ifndef DISTNTF_DISTAUNTF_HPP_
#define DISTNTF_DISTAUNTF_HPP_

#include <armadillo>
#include <string>
#include <vector>
#include "common/distutils.hpp"
#include "common/ntf_utils.hpp"
#include "dimtree/ddt.hpp"
#include "distntf/distntfmpicomm.hpp"  // for NTFMPICommunicator
#include "distntf/distntftime.hpp"

// #define DISTNTF_VERBOSE 1

namespace planc {

#define TENSOR_LOCAL_DIM (m_input_tensor.dimensions())
#define TENSOR_LOCAL_NUMEL (m_input_tensor.numel())
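
/**
 * Distributed AU-NTF (alternating-updating non-negative tensor
 * factorization) driver. This class owns the MPI communication, the
 * distributed MTTKRP, and the gram-matrix bookkeeping; a derived class
 * supplies the per-mode factor update by implementing update().
 */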
class DistAUNTF {
 protected:
  // communication related variables
  const NTFMPICommunicator &m_mpicomm;
  // NLS solve sizes
  UVEC m_nls_sizes;
  UVEC m_nls_idxs;
  // local ncp factors
  NCPFactors m_local_ncp_factors;
  NCPFactors m_local_ncp_factors_t;
  // mttkrp related variables
  MAT *ncp_mttkrp_t;
  MAT *ncp_local_mttkrp_t;
  // hadamard of the global_grams
  MAT global_gram;

  virtual MAT update(int current_mode) = 0;

 private:
  const Tensor &m_input_tensor;
  NCPFactors m_gathered_ncp_factors;
  NCPFactors m_gathered_ncp_factors_t;
  // mttkrp related variables
  MAT *ncp_krp;
  // gram related variables.
  MAT factor_local_grams;    // U in the algorithm.
  MAT *factor_global_grams;  // G in the algorithm.

  // NTF related variables.
  const unsigned int m_low_rank_k;
  const unsigned int m_modes;
  const algotype m_updalgo;
  const UVEC m_global_dims;
  const UVEC m_factor_local_dims;
  unsigned int m_num_it;
  unsigned int current_mode;
  FVEC m_regularizers;
  bool m_compute_error;
  bool m_enable_dim_tree;
  unsigned int m_current_it;
  double m_rel_error;

  // needed for acceleration algorithms.
  bool m_accelerated;
  std::vector<bool> m_stale_mttkrp;
  // stats
  DistNTFTime time_stats;

  // error computation related.
  double m_global_sqnorm_A;
  MAT hadamard_all_grams;

  DenseDimensionTree *kdt;
  void update_global_gram(const int current_mode) {
    // computing U
    MPITIC;  // gram
    // force a ssyrk instead of gemm.
    MAT H = m_local_ncp_factors.factor(current_mode);
    factor_local_grams = H.t() * H;

    double temp = MPITOC;  // gram
    this->time_stats.compute_duration(temp);
    this->time_stats.gram_duration(temp);
    factor_global_grams[current_mode].zeros();
    // Computing G.
    MPITIC;  // allreduce gram
    MPI_Allreduce(factor_local_grams.memptr(),
                  factor_global_grams[current_mode].memptr(),
                  this->m_low_rank_k * this->m_low_rank_k, MPI_DOUBLE, MPI_SUM,
                  MPI_COMM_WORLD);
    temp = MPITOC;  // allreduce gram
    applyReg(this->m_regularizers(current_mode * 2),
             this->m_regularizers(current_mode * 2 + 1),
             &(factor_global_grams[current_mode]));
    this->time_stats.communication_duration(temp);
    this->time_stats.allreduce_duration(temp);
  }

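  /**
   * Adds the regularization terms of one mode to its global gram matrix.
   * m_regularizers holds two entries per mode, laid out as
   * [l2_0, l1_0, l2_1, l1_1, ...]; update_global_gram() passes entries
   * 2n and 2n+1 for mode n. A caller would set them, e.g. (illustrative
   * values only):
   *
   * @code
   *   FVEC regs = arma::zeros<FVEC>(2 * ntf.modes());  // no regularization
   *   ntf.regularizers(regs);
   * @endcode
   */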
  void applyReg(float lambda_l2, float lambda_l1, MAT *AtA) {
    // Frobenius norm regularization
    if (lambda_l2 > 0) {
      MAT identity = arma::eye<MAT>(this->m_low_rank_k, this->m_low_rank_k);
      (*AtA) = (*AtA) + 2 * lambda_l2 * identity;
    }

    // L1 - norm regularization
    if (lambda_l1 > 0) {
      MAT onematrix = arma::ones<MAT>(this->m_low_rank_k, this->m_low_rank_k);
      (*AtA) = (*AtA) + 2 * lambda_l1 * onematrix;
    }
  }

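  /**
   * Forms the Hadamard product of the global gram matrices of all modes
   * except the current one: global_gram = G^(0) % ... % G^(N-1), skipping
   * mode n, where % is the element-wise product. This k x k matrix is the
   * coefficient matrix that the mode-n solve in update() works against.
   */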
  void gram_hadamard(unsigned int current_mode) {
    global_gram.ones();
    MPITIC;  // gram hadamard
    for (unsigned int i = 0; i < m_modes; i++) {
      if (i != current_mode) {
        // %= is element-wise multiplication
        global_gram %= factor_global_grams[i];
      }
    }
    double temp = MPITOC;  // gram hadamard
    this->time_stats.compute_duration(temp);
    this->time_stats.gram_duration(temp);
  }
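
  /**
   * Gathers the full factor matrix of the current mode onto every process
   * in the mode's slice communicator with MPI_Allgatherv. Each rank
   * contributes its m_nls_sizes[mode] x k block (transposed, so the data is
   * contiguous by row block); counts and displacements come from
   * itersplit()/startidx(). For illustration only, a block split where the
   * first n % p ranks hold one extra row would look like:
   *
   * @code
   *   // hypothetical sketch, not the planc implementation:
   *   int itersplit(int n, int p, int r) { return n / p + (r < n % p); }
   *   int startidx(int n, int p, int r) {
   *     return r * (n / p) + std::min(r, n % p);
   *   }
   * @endcode
   */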
  void gather_ncp_factor(const int current_mode) {
    m_gathered_ncp_factors_t.factor(current_mode).zeros();
    // Had this comment for debugging memory corruption in all_gather
    // DISTPRINTINFO("::ncp_krp::" << ncp_krp[current_mode].memptr()
    //               << "::size::" << ncp_krp[current_mode].n_rows
    //               << "x" << ncp_krp[current_mode].n_cols
    //               << "::m_gathered_ncp_factors_t::"
    //               << m_gathered_ncp_factors_t.factor(current_mode).memptr()
    //               << "::diff from recvcnt::"
    //               << m_gathered_ncp_factors_t.factor(current_mode).memptr()
    //                      - recvcnt * 8);

    MPI_Comm current_slice_comm = this->m_mpicomm.slice(current_mode);
    int slice_size;
    MPI_Comm_size(current_slice_comm, &slice_size);

    // int sendcnt = m_local_ncp_factors.factor(current_mode).n_elem;
    int sendcnt = m_nls_sizes[current_mode] * m_low_rank_k;

    // int recvcnt = m_local_ncp_factors.factor(current_mode).n_elem;
    std::vector<int> recvgathercnt(slice_size, 0);
    std::vector<int> recvgatherdispl(slice_size, 0);

    int dimsize = m_factor_local_dims[current_mode];
    for (int i = 0; i < slice_size; i++) {
      recvgathercnt[i] = itersplit(dimsize, slice_size, i) * m_low_rank_k;
      recvgatherdispl[i] = startidx(dimsize, slice_size, i) * m_low_rank_k;
    }

#ifdef DISTNTF_VERBOSE
    MPI_Comm current_fiber_comm = this->m_mpicomm.fiber(current_mode);
    int fiber_size;

    MPI_Comm_size(current_fiber_comm, &fiber_size);
    MPI_Comm_size(current_slice_comm, &slice_size);
    DISTPRINTINFO("::current_mode::"
                  << current_mode << "::fiber comm size::" << fiber_size
                  << "::my_global_rank::" << MPI_RANK << "::my_slice_rank::"
                  << this->m_mpicomm.slice_rank(current_mode)
                  << "::my_fiber_rank::"
                  << this->m_mpicomm.fiber_rank(current_mode)
                  << "::sendcnt::" << sendcnt << "::gathered factor size::"
                  << m_gathered_ncp_factors_t.factor(current_mode).n_elem);
#endif
    MPITIC;  // allgather tic
    MPI_Allgatherv(m_local_ncp_factors_t.factor(current_mode).memptr(), sendcnt,
                   MPI_DOUBLE,
                   m_gathered_ncp_factors_t.factor(current_mode).memptr(),
                   &recvgathercnt[0], &recvgatherdispl[0], MPI_DOUBLE,
                   // todo:: check whether it is slice or fiber while running
                   // and debugging the code.
                   current_slice_comm);
    double temp = MPITOC;  // allgather toc
    this->time_stats.communication_duration(temp);
    this->time_stats.allgather_duration(temp);
#ifdef DISTNTF_VERBOSE
    DISTPRINTINFO("sent local factor::"
                  << std::endl
                  << m_local_ncp_factors_t.factor(current_mode) << std::endl
                  << " gathered factor::" << std::endl
                  << m_gathered_ncp_factors_t.factor(current_mode));
#endif
    // keep gather_ncp_factors_t consistent.
    MPITIC;  // transpose tic
    m_gathered_ncp_factors.set(
        current_mode, m_gathered_ncp_factors_t.factor(current_mode).t());
    temp = MPITOC;  // transpose toc
    this->time_stats.compute_duration(temp);
    this->time_stats.trans_duration(temp);
  }

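  /**
   * Computes the distributed MTTKRP for the current mode. Without the
   * dimension tree, the Khatri-Rao product of all other gathered factors is
   * formed explicitly and multiplied against the local tensor unfolding;
   * with the dimension tree enabled, kdt reuses partial tensor-times-vector
   * results instead. The partial results are then summed across the slice
   * communicator with MPI_Reduce_scatter, so each rank keeps exactly the
   * rows matching its local factor block.
   */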
  void distmttkrp(const int &current_mode) {
    double temp;
    if (!this->m_enable_dim_tree) {
      MPITIC;  // krp tic
      m_gathered_ncp_factors.krp_leave_out_one(current_mode,
                                               &ncp_krp[current_mode]);
      temp = MPITOC;  // krp toc
      this->time_stats.compute_duration(temp);
      this->time_stats.krp_duration(temp);
    }

    if (this->m_enable_dim_tree) {
      double multittv_time = 0;
      double mttkrp_time = 0;
      kdt->in_order_reuse_MTTKRP(current_mode,
                                 ncp_mttkrp_t[current_mode].memptr(), false,
                                 multittv_time, mttkrp_time);
      this->time_stats.compute_duration(multittv_time);
      this->time_stats.compute_duration(mttkrp_time);
      this->time_stats.multittv_duration(multittv_time);
      this->time_stats.mttkrp_duration(mttkrp_time);
    } else {
      MPITIC;  // mttkrp tic
      m_input_tensor.mttkrp(current_mode, ncp_krp[current_mode],
                            &ncp_mttkrp_t[current_mode]);
      temp = MPITOC;  // mttkrp toc
      this->time_stats.compute_duration(temp);
      this->time_stats.mttkrp_duration(temp);
    }
    // verify that the dimension tree output matches the classic one
    // MAT kdt_ncp_mttkrp_t = ncp_mttkrp_t[current_mode];
    // bool same_mttkrp = arma::approx_equal(kdt_ncp_mttkrp_t,
    //     ncp_mttkrp_t[current_mode], "absdiff", 1e-3);
    // PRINTROOT("kdt vs mttkrp_t::" << same_mttkrp);
    // MAT ncp_mttkrp = ncp_mttkrp_t[current_mode].t();
    // same_mttkrp = arma::approx_equal(kdt_ncp_mttkrp_t, ncp_mttkrp,
    //     "absdiff", 1e-3);
    // PRINTROOT("kdt vs mttkrp::" << same_mttkrp);
    // PRINTROOT("kdt mttkrp::" << kdt_ncp_mttkrp_t);
    // PRINTROOT("classic mttkrp_t::" << ncp_mttkrp_t[current_mode]);

    MPI_Comm current_slice_comm = this->m_mpicomm.slice(current_mode);
    int slice_size;
    int slice_rank;
    MPI_Comm_size(current_slice_comm, &slice_size);
    slice_rank = this->m_mpicomm.slice_rank(current_mode);
    std::vector<int> recvmttkrpsize(slice_size);
    int dimsize = m_factor_local_dims[current_mode];
    for (int i = 0; i < slice_size; i++) {
      recvmttkrpsize[i] = itersplit(dimsize, slice_size, i) * m_low_rank_k;
    }
#ifdef DISTNTF_VERBOSE
    MPI_Comm current_fiber_comm = this->m_mpicomm.fiber(current_mode);
    int fiber_size;
    MPI_Comm_size(current_fiber_comm, &fiber_size);
    DISTPRINTINFO("::current_mode::"
                  << current_mode << "::slice comm size::" << slice_size
                  << "::fiber comm size::" << fiber_size
                  << "::my_global_rank::" << MPI_RANK << "::my_slice_rank::"
                  << this->m_mpicomm.slice_rank(current_mode)
                  << "::my_fiber_rank::"
                  << this->m_mpicomm.fiber_rank(current_mode)
                  << "::mttkrp_size::" << ncp_mttkrp_t[current_mode].n_elem
                  << "::local_mttkrp_size::"
                  << ncp_local_mttkrp_t[current_mode].n_elem);
#endif
    ncp_local_mttkrp_t[current_mode].zeros();
    MPITIC;  // reduce_scatter mttkrp
    MPI_Reduce_scatter(ncp_mttkrp_t[current_mode].memptr(),
                       ncp_local_mttkrp_t[current_mode].memptr(),
                       &recvmttkrpsize[0], MPI_DOUBLE, MPI_SUM,
                       current_slice_comm);
    temp = MPITOC;  // reduce_scatter mttkrp
    this->time_stats.communication_duration(temp);
    this->time_stats.reducescatter_duration(temp);
#ifdef DISTNTF_VERBOSE
    DISTPRINTINFO(ncp_mttkrp_t[current_mode]);
    DISTPRINTINFO(ncp_local_mttkrp_t[current_mode]);
#endif
    this->m_stale_mttkrp[current_mode] = false;
  }

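  /**
   * Allocates the per-mode work buffers. For mode i with local dimension
   * d_i: ncp_krp[i] is (numel / d_i) x k, ncp_mttkrp_t[i] is k x d_i,
   * ncp_local_mttkrp_t[i] matches the transposed local factor block, and
   * each factor_global_grams[i] is k x k.
   */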
  void allocateMatrices() {
    // allocate matrices.
    if (!m_enable_dim_tree) {
      ncp_krp = new MAT[m_modes];
    }
    ncp_mttkrp_t = new MAT[m_modes];
    ncp_local_mttkrp_t = new MAT[m_modes];
    factor_global_grams = new MAT[m_modes];
    factor_local_grams.zeros(this->m_low_rank_k, this->m_low_rank_k);
    global_gram.ones(this->m_low_rank_k, this->m_low_rank_k);
    UWORD current_size = 0;
    for (unsigned int i = 0; i < m_modes; i++) {
      current_size = TENSOR_LOCAL_NUMEL / TENSOR_LOCAL_DIM[i];
      if (!m_enable_dim_tree) {
        ncp_krp[i] = arma::zeros(current_size, this->m_low_rank_k);
      }
      ncp_mttkrp_t[i] = arma::zeros(this->m_low_rank_k, TENSOR_LOCAL_DIM[i]);
      ncp_local_mttkrp_t[i] = arma::zeros(m_local_ncp_factors.factor(i).n_cols,
                                          m_local_ncp_factors.factor(i).n_rows);
      factor_global_grams[i] =
          arma::zeros(this->m_low_rank_k, this->m_low_rank_k);
    }
  }

  void freeMatrices() {
    for (unsigned int i = 0; i < m_modes; i++) {
      if (!m_enable_dim_tree) {
        ncp_krp[i].clear();
      }
      ncp_mttkrp_t[i].clear();
      ncp_local_mttkrp_t[i].clear();
      factor_global_grams[i].clear();
    }
    if (!m_enable_dim_tree) {
      delete[] ncp_krp;
    }
    delete[] ncp_mttkrp_t;
    delete[] ncp_local_mttkrp_t;
    delete[] factor_global_grams;
  }

  void reportTime(const double temp, const std::string &reportstring) {
    double mintemp, maxtemp, sumtemp;
    MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
    MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
    MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    PRINTROOT(reportstring  // << "::dims::" << this->m_global_dims.t()
              << "::k::" << this->m_low_rank_k << "::SIZE::" << MPI_SIZE
              << "::algo::" << this->m_updalgo << "::root::" << temp
              << "::min::" << mintemp << "::avg::" << (sumtemp) / (MPI_SIZE)
              << "::max::" << maxtemp);
  }

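  /**
   * Installs a freshly solved factor for one mode and refreshes all state
   * that depends on it: normalizes the factor, keeps the transposed copy
   * and lambda in sync, recomputes the mode's global gram, re-gathers the
   * factor across the slice (updating the dimension tree if enabled), and
   * marks the MTTKRPs of every other mode stale.
   */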
  void update_factor_mode(const unsigned int current_mode, const MAT &factor) {
    m_local_ncp_factors.set(current_mode, factor);
    m_local_ncp_factors.distributed_normalize(current_mode);
    MAT factor_t = m_local_ncp_factors.factor(current_mode).t();
    m_local_ncp_factors_t.set(current_mode, factor_t);
    m_local_ncp_factors_t.set_lambda(m_local_ncp_factors.lambda());
    // line 13 and 14
    update_global_gram(current_mode);
    // line 15
    gather_ncp_factor(current_mode);
    if (this->m_enable_dim_tree) {
      kdt->set_factor(m_gathered_ncp_factors_t.factor(current_mode).memptr(),
                      current_mode);
    }
    for (unsigned int mode = 0; mode < this->m_modes; mode++) {
      if (mode != current_mode) this->m_stale_mttkrp[mode] = true;
    }
  }

  virtual void accelerate() {}

  void generateReport() {
    MPI_Barrier(MPI_COMM_WORLD);
    this->reportTime(this->time_stats.duration(), "total_d");
    this->reportTime(this->time_stats.communication_duration(), "total_comm");
    this->reportTime(this->time_stats.compute_duration(), "total_comp");
    this->reportTime(this->time_stats.allgather_duration(), "total_allgather");
    this->reportTime(this->time_stats.allreduce_duration(), "total_allreduce");
    this->reportTime(this->time_stats.reducescatter_duration(),
                     "total_reducescatter");
    this->reportTime(this->time_stats.gram_duration(), "total_gram");
    this->reportTime(this->time_stats.krp_duration(), "total_krp");
    this->reportTime(this->time_stats.mttkrp_duration(), "total_mttkrp");
    this->reportTime(this->time_stats.multittv_duration(), "total_multittv");
    this->reportTime(this->time_stats.nnls_duration(), "total_nnls");
    if (this->m_compute_error) {
      this->reportTime(this->time_stats.err_compute_duration(),
                       "total_err_compute");
      this->reportTime(this->time_stats.err_communication_duration(),
                       "total_err_communication");
    }
  }

 public:
  DistAUNTF(const Tensor &i_tensor, const int i_k, algotype i_algo,
            const UVEC &i_global_dims, const UVEC &i_local_dims,
            const UVEC &i_nls_sizes, const UVEC &i_nls_idxs,
            const NTFMPICommunicator &i_mpicomm)
      : m_mpicomm(i_mpicomm),
        m_nls_sizes(i_nls_sizes),
        m_nls_idxs(i_nls_idxs),
        m_local_ncp_factors(i_nls_sizes, i_k, false),
        m_local_ncp_factors_t(i_nls_sizes, i_k, true),
        m_input_tensor(i_tensor),
        m_gathered_ncp_factors(i_tensor.dimensions(), i_k, false),
        m_gathered_ncp_factors_t(i_tensor.dimensions(), i_k, true),
        m_low_rank_k(i_k),
        m_modes(m_input_tensor.modes()),
        m_updalgo(i_algo),
        m_global_dims(i_global_dims),
        m_factor_local_dims(i_local_dims),
        time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
    this->m_compute_error = false;
    this->m_enable_dim_tree = false;
    this->m_accelerated = false;
    this->m_num_it = 30;
    this->m_rel_error = 1.0;
    // randomize again. otherwise the factors will be the same
    // on every process.
    m_local_ncp_factors.randu(149 * i_mpicomm.rank() + 103);
    m_local_ncp_factors.distributed_normalize();
    for (unsigned int i = 0; i < this->m_modes; i++) {
      MAT current_factor = arma::trans(m_local_ncp_factors.factor(i));
      m_local_ncp_factors_t.set(i, current_factor);
      this->m_stale_mttkrp.push_back(true);
    }
    m_local_ncp_factors_t.set_lambda(m_local_ncp_factors.lambda());
    m_gathered_ncp_factors.trans(m_gathered_ncp_factors_t);
    allocateMatrices();
    double normA = i_tensor.norm();
    MPI_Allreduce(&normA, &this->m_global_sqnorm_A, 1, MPI_DOUBLE, MPI_SUM,
                  MPI_COMM_WORLD);

    DISTPRINTINFO("::NLS Solve Sizes::"
                  << m_nls_sizes << "::NLS start indices::" << m_nls_idxs);
  }

  ~DistAUNTF() {
    freeMatrices();
    if (this->m_enable_dim_tree) {
      delete kdt;
    }
  }
  /// Sets the number of outer iterations.
  void num_iterations(const int i_n) { this->m_num_it = i_n; }
  /// Returns the number of modes of the tensor.
  size_t modes() const { return this->m_modes; }
  /// Returns the low rank k.
  size_t rank() const { return this->m_low_rank_k; }
  /// L1 and L2 regularization for every mode.
  void regularizers(const FVEC i_regs) { this->m_regularizers = i_regs; }
  /// Sets whether to compute the error or not.
  void compute_error(bool i_error) {
    this->m_compute_error = i_error;
    hadamard_all_grams =
        arma::ones<MAT>(this->m_low_rank_k, this->m_low_rank_k);
  }
  /// Returns the lambda of the NCP factors.
  VEC lambda() { return m_local_ncp_factors.lambda(); }
  /// Returns the current outer iteration of computeNTF.
  int current_it() const { return this->m_current_it; }
  /// Returns the current relative error.
  double current_error() const { return this->m_rel_error; }
  /// MTTKRP can be computed with or without dimension trees.
  void dim_tree(bool i_dim_tree) {
    this->m_enable_dim_tree = i_dim_tree;
    if (this->m_enable_dim_tree) {
      if (this->ncp_krp != NULL) {
        for (unsigned int i = 0; i < m_modes; i++) {
          ncp_krp[i].clear();
        }
        delete[] ncp_krp;
        ncp_krp = NULL;  // guard against a double delete on repeated calls.
      }
    }
  }
  /// Enables acceleration; also turns on error computation.
  void accelerated(const bool &set_acceleration) {
    this->m_accelerated = set_acceleration;
    this->m_compute_error = true;
  }

  bool is_stale_mttkrp(const int &current_mode) const {
    return this->m_stale_mttkrp[current_mode];
  }

  /// Completely resets all the factors and the state of the AUNTF.
  void reset(const NCPFactors &new_factors, bool trans = false) {
    if (!trans) {
      for (unsigned int i = 0; i < m_modes; i++) {
        update_factor_mode(i, new_factors.factor(i));
      }
    } else {
      for (unsigned int i = 0; i < m_modes; i++) {
        update_factor_mode(i, new_factors.factor(i).t());
      }
    }
    m_local_ncp_factors.set_lambda(new_factors.lambda());
    m_local_ncp_factors_t.set_lambda(new_factors.lambda());
  }

  // Preferably call this after computeNTF().
  // This is currently used to save the factor matrices.
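  /**
   * Gathers the complete factor matrix of one mode onto rank 0 of the
   * mode's fiber communicator. The caller provides factor_matrix, which on
   * the root must have room for m_global_dims[mode] * k doubles; the data
   * arrives in the transposed layout used by m_gathered_ncp_factors_t.
   */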
  void factor(int mode, double *factor_matrix) {
    gather_ncp_factor(mode);
    int sendcnt = m_gathered_ncp_factors_t.factor(mode).n_elem;
    int fiber_size = this->m_mpicomm.proc_grids()[mode];
    int global_size = this->m_global_dims[mode];
    std::vector<int> recvcnts(fiber_size, 0);
    std::vector<int> displs(fiber_size, 0);

    DISTPRINTINFO("Collecting mode::" << mode << "::sendcnt::" << sendcnt
                                      << "::fiber_size::" << fiber_size
                                      << "::global_size::" << global_size);

    // int dimsize = m_factor_local_dims[current_mode];
    for (int i = 0; i < fiber_size; i++) {
      recvcnts[i] = itersplit(global_size, fiber_size, i) * m_low_rank_k;
      displs[i] = startidx(global_size, fiber_size, i) * m_low_rank_k;
    }

    MPI_Gatherv(m_gathered_ncp_factors_t.factor(mode).memptr(), sendcnt,
                MPI_DOUBLE, factor_matrix, &recvcnts[0], &displs[0], MPI_DOUBLE,
                // todo:: check whether it is slice or fiber while running
                // and debugging the code.
                0, this->m_mpicomm.fiber(mode));
  }

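  /**
   * The main computeNTF loop. A concrete solver derives from DistAUNTF and
   * implements update(); the driver below is an illustrative sketch only
   * (SimpleDistNTF and its clipped least-squares update are hypothetical,
   * not one of planc's shipped solvers):
   *
   * @code
   * class SimpleDistNTF : public planc::DistAUNTF {
   *  public:
   *   using DistAUNTF::DistAUNTF;  // inherit the constructor
   *
   *  protected:
   *   // Solve global_gram * H_t = mttkrp_t for the mode and clip
   *   // negatives to zero -- a naive stand-in for a real NNLS solver.
   *   MAT update(int current_mode) override {
   *     MAT H_t = arma::solve(global_gram, ncp_local_mttkrp_t[current_mode]);
   *     H_t.for_each([](arma::mat::elem_type &v) { if (v < 0.0) v = 0.0; });
   *     return H_t;  // k x local_dim; computeNTF() transposes it back.
   *   }
   * };
   *
   * // SimpleDistNTF ntf(tensor, k, algo, global_dims, local_dims,
   * //                   nls_sizes, nls_idxs, mpicomm);
   * // ntf.num_iterations(20);
   * // ntf.compute_error(true);
   * // ntf.computeNTF();
   * @endcode
   */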
  void computeNTF() {
    // initialize everything.
    // line 3,4,5 of the algorithm
    for (unsigned int i = 1; i < m_modes; i++) {
      update_global_gram(i);
      gather_ncp_factor(i);
    }
    if (this->m_enable_dim_tree) {
      // Determine the optimal split for the given mode ordering:
      // product of left dimensions \approx product of right dimensions.
      size_t split_criteria = arma::prod(m_input_tensor.dimensions());
      split_criteria = std::round(std::sqrt(split_criteria));
      UVEC temp_cum_prod = arma::cumprod(m_input_tensor.dimensions());
      int split_mode = 0;
      while (temp_cum_prod(split_mode) < split_criteria) {
        split_mode++;
      }
      PRINTROOT("KDT Split Mode::" << split_mode
                                   << "::split criteria::" << split_criteria
                                   << "::cum prod::" << std::endl
                                   << temp_cum_prod << std::endl);

      // check to see if the split mode is left or right.
      if (split_mode > 0) {
        size_t current_left = temp_cum_prod(split_mode);
        size_t good_criteria = temp_cum_prod(temp_cum_prod.n_rows - 1) /
                               temp_cum_prod(split_mode - 1);
        if (current_left > good_criteria) split_mode--;
        PRINTROOT("KDT Split Mode::"
                  << split_mode << "::split criteria::" << split_criteria
                  << "::numerator::" << temp_cum_prod(temp_cum_prod.n_rows - 1)
                  << "::good_criteria::" << good_criteria
                  << "::current_left::" << current_left << std::endl
                  << "::cum prod::" << std::endl
                  << temp_cum_prod << std::endl);
      }
      kdt = new DenseDimensionTree(m_input_tensor, m_gathered_ncp_factors,
                                   split_mode);
    }
#ifdef DISTNTF_VERBOSE
    DISTPRINTINFO("local factor matrices::");
    this->m_local_ncp_factors.print();
    DISTPRINTINFO("local factor matrices transpose::");
    this->m_local_ncp_factors_t.print();
    DISTPRINTINFO("gathered factor matrices::");
    this->m_gathered_ncp_factors.print();
#endif
    for (this->m_current_it = 0; this->m_current_it < m_num_it;
         this->m_current_it++) {
      MAT unnorm_factor;
      for (unsigned int current_mode = 0; current_mode < m_modes;
           current_mode++) {
        // line 9 and 10 of the algorithm
        if (is_stale_mttkrp(current_mode)) distmttkrp(current_mode);
        // line 11 of the algorithm
        gram_hadamard(current_mode);
        // line 12 of the algorithm
#ifdef DISTNTF_VERBOSE
        DISTPRINTINFO("local factor matrix::"
                      << this->m_local_ncp_factors.factor(current_mode));
        DISTPRINTINFO("gathered factor matrix::");
        this->m_gathered_ncp_factors.print();
        PRINTROOT("global_grams::" << std::endl << this->global_gram);
        DISTPRINTINFO("mttkrp::");
        this->ncp_local_mttkrp_t[current_mode].print();
#endif
        MPITIC;  // nnls_tic
        MAT factor = update(current_mode);
        double temp = MPITOC;  // nnls_toc
        this->time_stats.compute_duration(temp);
        this->time_stats.nnls_duration(temp);
#ifdef DISTNTF_VERBOSE
        DISTPRINTINFO("it::" << this->m_current_it << "::mode::" << current_mode
                             << std::endl
                             << factor);
#endif
        if (m_compute_error && current_mode == this->m_modes - 1) {
          unnorm_factor = factor;
        }
        update_factor_mode(current_mode, factor.t());
      }
      if (m_compute_error) {
        double temp_err = computeError(unnorm_factor, this->m_modes - 1);
        this->m_rel_error = temp_err;
        double iter_time = this->time_stats.compute_duration() +
                           this->time_stats.communication_duration();
        PRINTROOT("Iter::" << this->m_current_it << "::k::"
                           << this->m_low_rank_k << "::SIZE::" << MPI_SIZE
                           << "::algo::" << this->m_updalgo << "::time::"
                           << iter_time << "::relative_error::" << temp_err);
      }
      if (this->m_accelerated) {
        // acceleration may be possible; call the accelerate method
        // of the derived class.
        accelerate();
      }
      PRINTROOT("completed it::" << this->m_current_it);
    }
    generateReport();
  }
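
  /**
   * Computes the relative error without forming the low-rank tensor,
   * using the identity
   *
   *   ||X - X_hat||^2 = ||X||^2 + lambda^T (G^(mode) % global_gram) lambda
   *                     - 2 <X, X_hat>,
   *
   * where % is the element-wise product, the middle term is the squared
   * norm of the model, and the inner product <X, X_hat> is the dot product
   * of the local MTTKRP with the unnormalized last factor, allreduced over
   * all processes. The returned value is sqrt(|squared_err| / ||X||^2);
   * small negative values from round-off are absorbed by the absolute
   * value.
   */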
  double computeError(const MAT &unnorm_factor, int mode) {
    // rel_Error = sqrt(max(init.nr_X^2 + lambda^T * Hadamard of all grams *
    //             lambda - 2 * innerprod(X, F_kten), 0)) / init.nr_X;
    MPITIC;  // err compute
    hadamard_all_grams = global_gram % factor_global_grams[mode];
    VEC local_lambda = m_local_ncp_factors.lambda();
    ROWVEC temp_vec = local_lambda.t() * hadamard_all_grams;
    double sq_norm_model = arma::dot(temp_vec, local_lambda);
    // double sq_norm_model = arma::norm(hadamard_all_grams, "fro");
    // sum of the element-wise dot product between the local mttkrp and
    // the factor matrix
    double inner_product = arma::dot(ncp_local_mttkrp_t[mode], unnorm_factor);
    double temp = MPITOC;  // err compute
    this->time_stats.compute_duration(temp);
    this->time_stats.err_compute_duration(temp);
    double all_inner_product;
    MPITIC;  // err comm
    MPI_Allreduce(&inner_product, &all_inner_product, 1, MPI_DOUBLE, MPI_SUM,
                  MPI_COMM_WORLD);
    temp = MPITOC;  // err comm
    this->time_stats.communication_duration(temp);
    this->time_stats.err_communication_duration(temp);
#ifdef DISTNTF_VERBOSE
    DISTPRINTINFO("local_lambda::" << local_lambda);
    DISTPRINTINFO("local_inner_product::" << inner_product << std::endl);
    PRINTROOT("norm_A_sq :: "
              << this->m_global_sqnorm_A << "::model_norm_sq::" << sq_norm_model
              << "::global_inner_product::" << all_inner_product << std::endl);
#endif
    double squared_err =
        this->m_global_sqnorm_A + sq_norm_model - 2 * all_inner_product;
    if (squared_err < 0) {
      PRINTROOT("computed error is negative due to round off");
      PRINTROOT("norm_A_sq :: "
                << this->m_global_sqnorm_A
                << "::model_norm_sq::" << sq_norm_model
                << "::global_inner_product::" << all_inner_product
                << "::squared_err::" << squared_err << std::endl);
    }
    return std::sqrt(std::abs(squared_err) / this->m_global_sqnorm_A);
  }
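
  /**
   * Error variant for a proposed set of (transposed) factors: resets the
   * state to new_factors_t, recomputes the MTTKRP and gram Hadamard for the
   * given mode, and evaluates the same error formula as above, e.g. from an
   * accelerate() implementation in a derived class to test a candidate
   * iterate.
   */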
  double computeError(const NCPFactors &new_factors_t, const int mode) {
    // rel_Error = sqrt(max(init.nr_X^2 + lambda^T * Hadamard of all grams *
    //             lambda - 2 * innerprod(X, F_kten), 0)) / init.nr_X;
    // Reset with the new factors and compute the error on the given mode.
    reset(new_factors_t, true);
    distmttkrp(mode);
    gram_hadamard(mode);
    hadamard_all_grams = global_gram % factor_global_grams[mode];
    VEC local_lambda = m_local_ncp_factors.lambda();
    MAT unnorm_factor =
        arma::diagmat(local_lambda) * new_factors_t.factor(mode);
    ROWVEC temp_vec = local_lambda.t() * hadamard_all_grams;
    double sq_norm_model = arma::dot(temp_vec, local_lambda);
    // double sq_norm_model = arma::norm(hadamard_all_grams, "fro");
    // sum of the element-wise dot product between the local mttkrp and
    // the factor matrix
    double inner_product = arma::dot(ncp_local_mttkrp_t[mode], unnorm_factor);
    double all_inner_product;
    MPI_Allreduce(&inner_product, &all_inner_product, 1, MPI_DOUBLE, MPI_SUM,
                  MPI_COMM_WORLD);
    double squared_err =
        this->m_global_sqnorm_A + sq_norm_model - 2 * all_inner_product;
    if (squared_err < 0) {
      PRINTROOT("computed error is negative due to round off");
      PRINTROOT("norm_A_sq :: "
                << this->m_global_sqnorm_A
                << "::model_norm_sq::" << sq_norm_model
                << "::global_inner_product::" << all_inner_product
                << "::squared_err::" << squared_err << std::endl);
    }
    return std::sqrt(std::abs(squared_err) / this->m_global_sqnorm_A);
  }
};  // class DistAUNTF
}  // namespace planc
#endif  // DISTNTF_DISTAUNTF_HPP_