planc
Parallel Lowrank Approximation with Non-negativity Constraints
naiveanlsbpp.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_NAIVE_ANLS_BPP_HPP_
4 #define DISTNMF_NAIVE_ANLS_BPP_HPP_
5 #pragma once
6 #include "distnmf/distnmf1D.hpp"
7 #include "nnls/bppnnls.hpp"
8 
9 namespace planc {
10 
11 template <class INPUTMATTYPE>
12 class DistNaiveANLSBPP : public DistNMF1D<INPUTMATTYPE> {
13  MAT HtH, WtW;
14  MAT AcolstW, ArowsH;
15  MAT ArowsHt, AcolstWt;
16  ROWVEC localWnorm;
17  ROWVEC Wnorm;
18 
19  void printConfig() {
20  PRINTROOT("NAIVEANLSBPP constructor completed::"
21  << "::A::" << this->m_globalm << "x" << this->m_globaln
22  << "::norm::" << this->m_globalsqnormA << "::k::" << this->m_k
23  << "::it::" << this->m_num_iterations);
24  }
25 
26  public:
27  DistNaiveANLSBPP(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols,
28  const MAT &leftlowrankfactor, const MAT &rightlowrankfactor,
29  const MPICommunicator &communicator)
30  : DistNMF1D<INPUTMATTYPE>(Arows, Acols, leftlowrankfactor,
31  rightlowrankfactor, communicator) {
32  // allocate memory for matrices
33  HtH.zeros(this->m_k, this->m_k);
34  WtW.zeros(this->m_k, this->m_k);
35  AcolstW.zeros(this->globaln() / this->m_mpicomm.size(), this->m_k);
36  AcolstWt.zeros(this->m_k, this->globaln() / this->m_mpicomm.size());
37  ArowsH.zeros(this->globalm() / this->m_mpicomm.size(), this->m_k);
38  ArowsHt.zeros(this->m_k, this->globalm() / this->m_mpicomm.size());
39  localWnorm.zeros(this->m_k);
40  Wnorm.zeros(this->m_k);
41  PRINTROOT("NAIVEANLSBPP Constructor completed");
42  printConfig();
43  }
44  void computeNMF() {
45  PRINTROOT("starting computeNMF");
46 #ifdef MPI_VERBOSE
47  DISTPRINTINFO(PRINTMAT(this->m_Arows));
48  DISTPRINTINFO(PRINTMAT(this->m_Acols));
49 #endif
50  // we need only Acolst. So we are transposing and keeping it.
51  // Also for dense matrix, having duplicate copy is costly.
52 #ifndef BUILD_SPARSE
53  inplace_trans(this->m_Acols);
54  INPUTMATTYPE Acolst(this->m_Acols.memptr(), this->m_Acols.n_rows,
55  this->m_Acols.n_cols, false, true);
56 #else
57  INPUTMATTYPE Acolst = this->m_Acols.t();
58 #endif
59 
60  for (unsigned int iter = 0; iter < this->num_iterations(); iter++) {
61  if (iter > 0 && this->is_compute_error()) {
62  this->m_prevH = this->m_H;
63  this->m_prevHtH = this->HtH;
64  }
65  mpitic(); // total_d W&H
66  // update H given A,W
67  {
68  double tempTime = this->globalW();
69  this->time_stats.communication_duration(tempTime);
70  this->time_stats.allgather_duration(tempTime);
71  mpitic(); // gramW
72  WtW = this->m_globalW.t() * this->m_globalW;
73  tempTime = mpitoc(); // gramW
74  this->time_stats.compute_duration(tempTime);
75  this->time_stats.gram_duration(tempTime);
77 #ifdef MPI_VERBOSE
78  DISTPRINTINFO(PRINTMAT(this->m_W));
79  PRINTROOT(PRINTMAT(this->m_globalW));
80  DISTPRINTINFO(PRINTMAT(WtW));
81 #endif
82  tempTime = -1;
83  mpitic(); // mmH
84  AcolstW = Acolst * this->m_globalW;
85 // #if defined BUILD_SPARSE && MKL_FOUND
86 // ARMAMKLSCSCMM(this->m_Acols, 'T', this->m_globalWt,
87 // AcolstWt.memptr());
88 // #ifdef MPI_VERBOSE
89 // DISTPRINTINFO(PRINTMAT(AcolstWt));
90 // #endif
91 // AcolstW = reshape(AcolstWt.t(), this->globaln() /
92 // this->m_mpicomm.size(), this->m_k);
93 // #else
94 // AcolstW = Acolst * this->m_globalW;
95 // #endif
96 #ifdef MPI_VERBOSE
97  DISTPRINTINFO(PRINTMAT(AcolstW));
98 #endif
99  tempTime = mpitoc(); // mmH
100  PRINTROOT(PRINTMATINFO(this->m_Acols)
101  << PRINTMATINFO(this->m_globalW) << PRINTMATINFO(AcolstW));
102  this->time_stats.compute_duration(tempTime);
103  this->time_stats.mm_duration(tempTime);
104  this->reportTime(tempTime, "::AcolstW::");
105  mpitic(); // nnlsH
107  DISTPRINTINFO(PRINTMATINFO(AcolstW));
108  BPPNNLS<MAT, VEC> subProblem2(WtW, AcolstW, true);
109  subProblem2.solveNNLS();
110  this->m_Ht = subProblem2.getSolutionMatrix();
111  fixNumericalError<MAT>(&(this->m_Ht));
112  DISTPRINTINFO("OptimizeBlock::NNLS::" << PRINTMATINFO(this->m_Ht));
113 #ifdef MPI_VERBOSE
114  DISTPRINTINFO(PRINTMAT(this->m_Ht));
115 #endif
116  this->m_H = this->m_Ht.t();
117  tempTime = mpitoc(); // nnlsH
118  this->time_stats.compute_duration(tempTime);
119  this->time_stats.nnls_duration(tempTime);
120  this->reportTime(tempTime, "NNLS::H::");
121  }
122  // update W given A,H
123  {
124  double tempTime = this->globalH();
125  this->time_stats.communication_duration(tempTime);
126  this->time_stats.allgather_duration(tempTime);
127  mpitic(); // gramH
128  HtH = this->m_globalH.t() * this->m_globalH;
130 #ifdef MPI_VERBOSE
131  DISTPRINTINFO(PRINTMAT(this->m_H));
132  PRINTROOT(PRINTMAT(this->m_globalH));
133  DISTPRINTINFO(PRINTMAT(HtH));
134 #endif
135  tempTime = mpitoc(); // gramH
136  this->time_stats.compute_duration(tempTime);
137  this->time_stats.gram_duration(tempTime);
138  tempTime = -1;
139  mpitic(); // mmW
140  ArowsH = this->m_Arows * this->m_globalH;
141 // #if defined BUILD_SPARSE && MKL_FOUND
142 // ARMAMKLSCSCMM(this->m_Arows, 'N', this->m_globalHt,
143 // ArowsHt.memptr());
144 // #ifdef MPI_VERBOSE
145 // DISTPRINTINFO(PRINTMAT(ArowsHt));
146 // #endif
147 // ArowsH = ArowsHt.t();
148 // #else
149 // ArowsH = this->m_Arows * this->m_globalH;
150 // #endif
151 #ifdef MPI_VERBOSE
152  DISTPRINTINFO(PRINTMAT(ArowsH));
153 #endif
154  tempTime = mpitoc(); // mmW
155  PRINTROOT(PRINTMATINFO(this->m_Arows)
156  << PRINTMATINFO(this->m_globalH) << PRINTMATINFO(ArowsH));
157  this->time_stats.compute_duration(tempTime);
158  this->time_stats.mm_duration(tempTime);
159  this->reportTime(tempTime, "::ArowsH::");
160  mpitic(); // nnlsW
161  BPPNNLS<MAT, VEC> subProblem1(HtH, ArowsH, true);
162  subProblem1.solveNNLS();
163  this->m_Wt = subProblem1.getSolutionMatrix();
164  fixNumericalError<MAT>(&(this->m_Wt));
165  DISTPRINTINFO("OptimizeBlock::NNLS::" << PRINTMATINFO(this->m_Wt));
166 #ifdef MPI_VERBOSE
167  DISTPRINTINFO(PRINTMAT(this->m_Wt));
168 #endif
169  this->m_W = this->m_Wt.t();
170  tempTime = mpitoc(); // nnlsW
171  this->time_stats.compute_duration(tempTime);
172  this->time_stats.nnls_duration(tempTime);
173  this->reportTime(tempTime, "NNLS::W::");
174  }
175  this->time_stats.duration(mpitoc()); // total_d W&H
176  if (iter > 0 && this->is_compute_error()) {
177  this->computeError(WtW, this->m_prevHtH);
178  PRINTROOT("it=" << iter << "::algo::" << this->m_algorithm
179  << "::k::" << this->m_k
180  << "::err::" << this->m_objective_err << "::relerr::"
181  << this->m_objective_err / this->m_globalsqnormA);
182  }
183  PRINTROOT("completed it=" << iter
184  << "::taken::" << this->time_stats.duration());
185  }
186  MPI_Barrier(MPI_COMM_WORLD);
187  this->reportTime(this->time_stats.duration(), "total_d");
188  this->reportTime(this->time_stats.communication_duration(), "total_comm");
189  this->reportTime(this->time_stats.compute_duration(), "total_comp");
190  this->reportTime(this->time_stats.allgather_duration(), "total_allgather");
191  this->reportTime(this->time_stats.gram_duration(), "total_gram");
192  this->reportTime(this->time_stats.mm_duration(), "total_mm");
193  this->reportTime(this->time_stats.nnls_duration(), "total_nnls");
194  if (this->is_compute_error()) {
195  this->reportTime(this->time_stats.err_compute_duration(),
196  "total_err_compute");
197  this->reportTime(this->time_stats.err_compute_duration(),
198  "total_err_communication");
199  }
200  }
201 };
202 
203 } // namespace planc
204 
205 #endif // DISTNMF_NAIVE_ANLS_BPP_HPP_
MATTYPE getSolutionMatrix()
Definition: nnls.hpp:79
#define PRINTMAT(A)
Definition: utils.h:65
double mpitoc(int rank)
Definition: distutils.hpp:22
void reportTime(const double temp, const std::string &reportstring)
Definition: distnmf1D.hpp:175
const UWORD globaln() const
Definition: distnmf1D.hpp:169
const unsigned int num_iterations() const
Definition: distnmf1D.hpp:166
#define DISTPRINTINFO(MSG)
Definition: distutils.h:37
const bool is_compute_error() const
Definition: distnmf1D.hpp:173
void mpitic()
Definition: distutils.hpp:11
void computeError(const MAT &WtW, const MAT &HtH)
Definition: distnmf1D.hpp:122
int solveNNLS()
Definition: bppnnls.hpp:30
#define MAT
Definition: utils.h:52
double globalW()
Definition: distnmf1D.hpp:73
DistNaiveANLSBPP(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator)
#define ROWVEC
Definition: utils.h:54
ncp_factors contains the factors of the NCP. Every i-th factor is of size n_i * k; the number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32
#define PRINTMATINFO(A)
Definition: utils.h:63
const UWORD globalm() const
Definition: distnmf1D.hpp:168
double globalH()
Definition: distnmf1D.hpp:95