planc
Parallel Lowrank Approximation with Non-negativity Constraints
distntfaoadmm.hpp
Go to the documentation of this file.
1 /* Copyright Ramakrishnan Kannan 2018 */
2 
3 #ifndef DISTNTF_DISTNTFAOADMM_HPP_
4 #define DISTNTF_DISTNTFAOADMM_HPP_
5 
6 #include "distntf/distauntf.hpp"
7 
8 namespace planc {
9 
10 class DistNTFAOADMM : public DistAUNTF {
11  private:
12  // ADMM auxiliary variables
13  NCPFactors m_local_ncp_aux;
14  NCPFactors m_local_ncp_aux_t;
15  NCPFactors m_temp_local_ncp_aux_t;
16  MAT L;
17  MAT Lt;
18  MAT tempgram;
19  int admm_iter;
20  double tolerance;
21  double chol_time;
22  double stop_iter_time;
23  double proj_time;
24  double solve_time;
25  double norm_time;
26 
27  protected:
35  MAT update(const int mode) {
36  // return variable
37  MAT updated_fac(this->m_local_ncp_factors.factor(mode));
38  MAT prev_fac = updated_fac;
39 
40  // Set up ADMM iteration
41  double alpha = 0.0;
42 
43  MPITIC;
44  if (m_nls_sizes[mode] > 0) {
45  alpha = arma::trace(this->global_gram) / this->m_local_ncp_factors.rank();
46  alpha = (alpha > 0) ? alpha : 0.01;
47  tempgram = this->global_gram;
48  tempgram.diag() += alpha;
49  L = arma::chol(tempgram, "lower");
50  Lt = L.t();
51  }
52  chol_time += MPITOC;
53  bool stop_iter = false;
54 
55  // Start ADMM loop from here
56  for (int i = 0; i < admm_iter && !stop_iter; i++) {
57  if (m_nls_sizes[mode] > 0) {
58  prev_fac = updated_fac;
59  m_local_ncp_aux_t.set(mode, m_local_ncp_aux.factor(mode).t());
60  MPITIC;
61  m_temp_local_ncp_aux_t.set(
62  mode, arma::solve(arma::trimatl(L),
63  this->ncp_local_mttkrp_t[mode] +
64  (alpha * (updated_fac.t() +
65  m_local_ncp_aux_t.factor(mode)))));
66  m_local_ncp_aux_t.set(mode,
67  arma::solve(arma::trimatu(Lt),
68  m_temp_local_ncp_aux_t.factor(mode)));
69  solve_time += MPITOC;
70  // Update factor matrix
71  updated_fac = m_local_ncp_aux_t.factor(mode).t();
72  MPITIC;
73  fixNumericalError<MAT>(&(updated_fac), EPSILON_1EMINUS16);
74  updated_fac = updated_fac - m_local_ncp_aux.factor(mode);
75  updated_fac.for_each(
76  [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
77  proj_time += MPITOC;
78  // Update dual variable
79  m_local_ncp_aux.set(mode, m_local_ncp_aux.factor(mode) + updated_fac -
80  m_local_ncp_aux_t.factor(mode).t());
81  }
82  // stopping criteria variables
83  double local_facnorm = 0.0;
84  double local_dualnorm = 0.0;
85  double r = 0.0;
86  double s = 0.0;
87  if (m_nls_sizes[mode] > 0) {
88  local_facnorm = arma::norm(updated_fac, "fro");
89  local_dualnorm = arma::norm(m_local_ncp_aux.factor(mode), "fro");
90  r = norm(updated_fac.t() - m_local_ncp_aux_t.factor(mode), "fro");
91  s = norm(updated_fac - prev_fac, "fro");
92  }
93  MPITIC;
94  // factor norm
95  local_facnorm *= local_facnorm;
96  double global_facnorm = 0.0;
97  MPI_Allreduce(&local_facnorm, &global_facnorm, 1, MPI_DOUBLE, MPI_SUM,
98  MPI_COMM_WORLD);
99  global_facnorm = sqrt(global_facnorm);
100 
101  // dual norm
102  local_dualnorm *= local_dualnorm;
103  double global_dualnorm = 0.0;
104  MPI_Allreduce(&local_dualnorm, &global_dualnorm, 1, MPI_DOUBLE, MPI_SUM,
105  MPI_COMM_WORLD);
106  global_dualnorm = sqrt(global_dualnorm);
107 
108  // Check stopping criteria (needs communication)
109  r *= r;
110  double global_r = 0.0;
111  MPI_Allreduce(&r, &global_r, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
112  global_r = sqrt(global_r);
113 
114  s *= s;
115  double global_s = 0.0;
116  MPI_Allreduce(&s, &global_s, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
117  global_s = sqrt(global_s);
118  if (global_r < (tolerance * global_facnorm) &&
119  global_s < (tolerance * global_dualnorm))
120  stop_iter = true;
121  stop_iter_time += MPITOC;
122  }
123  MPITIC;
124  m_local_ncp_aux.distributed_normalize(mode);
125  norm_time += MPITOC;
126  return updated_fac.t();
127  }
128 
129  public:
130  DistNTFAOADMM(const Tensor &i_tensor, const int i_k, algotype i_algo,
131  const UVEC &i_global_dims, const UVEC &i_local_dims,
132  const UVEC &i_nls_sizes, const UVEC &i_nls_idxs,
133  const NTFMPICommunicator &i_mpicomm)
134  : DistAUNTF(i_tensor, i_k, i_algo, i_global_dims, i_local_dims,
135  i_nls_sizes, i_nls_idxs, i_mpicomm),
136  m_local_ncp_aux(i_nls_sizes, i_k, false),
137  m_local_ncp_aux_t(i_nls_sizes, i_k, true),
138  m_temp_local_ncp_aux_t(i_nls_sizes, i_k, true) {
139  m_local_ncp_aux.zeros();
140  m_local_ncp_aux_t.zeros();
141  m_temp_local_ncp_aux_t.zeros();
142  L.zeros(i_k, i_k);
143  Lt.zeros(i_k, i_k);
144  tempgram.zeros(i_k, i_k);
145  admm_iter = 5;
146  tolerance = 0.01;
147  chol_time = 0.0;
148  stop_iter_time = 0.0;
149  proj_time = 0.0;
150  solve_time = 0.0;
151  norm_time = 0.0;
152  }
153 
155  PRINTROOT("::chol time::" << chol_time
156  << "::stop_iter_time::" << stop_iter_time
157  << "::proj_time::" << proj_time
158  << "::solve_time::" << solve_time
159  << "::norm_time::" << norm_time);
160  }
161 
162 }; // class DistNTFAOADMM
163 
164 } // namespace planc
165 
166 #endif // DISTNTF_DISTNTFAOADMM_HPP_
Data is stored such that the unfolding is column major.
Definition: tensor.hpp:32
#define EPSILON_1EMINUS16
Definition: utils.h:43
#define MPITIC
Definition: distutils.h:26
#define MPITOC
Definition: distutils.h:27
algotype
Definition: utils.h:10
#define UVEC
Definition: utils.h:58
void set(const int i_n, const MAT &i_factor)
Set the mode i_n with the given factor matrix.
Definition: ncpfactors.hpp:112
#define MAT
Definition: utils.h:52
void zeros()
This is for reinitializing to zeros across different processors.
Definition: ncpfactors.hpp:382
ncp_factors contains the factors of the NCP; every i-th factor is of size n_i * k; the number of factors is ...
Definition: ncpfactors.hpp:20
DistNTFAOADMM(const Tensor &i_tensor, const int i_k, algotype i_algo, const UVEC &i_global_dims, const UVEC &i_local_dims, const UVEC &i_nls_sizes, const UVEC &i_nls_idxs, const NTFMPICommunicator &i_mpicomm)
#define PRINTROOT(MSG)
Definition: distutils.h:32
MAT & factor(const int i_n) const
factor matrix of a mode i_n
Definition: ncpfactors.hpp:100