// NOTE(review): this text is an extraction of the DistNaiveANLSBPP header
// with many source lines dropped. The leading integers on lines such as
// "3", "4", "11", "15", "20" look like fused line-number residue, not C++
// tokens, so the text will not compile as-is. Restore it from the original
// header before building.
//
// Include guard for the naive ANLS/BPP distributed NMF header, followed by
// the start of a class template parameterized on the input matrix type.
3 #ifndef DISTNMF_NAIVE_ANLS_BPP_HPP_ 4 #define DISTNMF_NAIVE_ANLS_BPP_HPP_ 11 template <
class INPUTMATTYPE>
// Transposed product buffers. Elsewhere in this file they are
// zero-initialized to k x (globalm()/P) and k x (globaln()/P), where P is
// the communicator size.
15 MAT ArowsHt, AcolstWt;
// Prints a configuration summary: the global dimensions m x n, the squared
// norm of A, the rank k and the iteration count. Presumably this prints on
// the root rank only, going by the PRINTROOT name.
// NOTE(review): the message says "constructor completed", but the enclosing
// function header is not visible here. Confirm which function this belongs
// to so the message is not misleading.
20 PRINTROOT(
"NAIVEANLSBPP constructor completed::" 21 <<
"::A::" << this->m_globalm <<
"x" << this->m_globaln
22 <<
"::norm::" << this->m_globalsqnormA <<
"::k::" << this->m_k
23 <<
"::it::" << this->m_num_iterations);
// Tail of the DistNaiveANLSBPP constructor's parameter list. The opening of
// the signature is not visible in this chunk. The appended symbol list shows
// it as (Arows, Acols, leftlowrankfactor, rightlowrankfactor, communicator).
// NOTE(review): leading integers such as "28", "30", "33" are fused
// line-number residue from extraction, not code.
28 const MAT &leftlowrankfactor,
const MAT &rightlowrankfactor,
// Delegates to the DistNMF1D base class. It passes the row and column
// partitions of A, the two low-rank factors and the communicator wrapper.
30 :
DistNMF1D<INPUTMATTYPE>(Arows, Acols, leftlowrankfactor,
31 rightlowrankfactor, communicator) {
// Zero-initialize the k x k Gram matrices (HtH, WtW).
33 HtH.zeros(this->m_k, this->m_k);
34 WtW.zeros(this->m_k, this->m_k);
// Per-rank product buffers. The local share of n (resp. m) is taken as
// globaln() / size (resp. globalm() / size).
// NOTE(review): this is integer division. When the global dimension is not
// a multiple of the communicator size, these sizes may not match a rank's
// actual local block. Confirm the partitioning guarantees divisibility, or
// size the buffers from the local block instead.
35 AcolstW.zeros(this->
globaln() / this->m_mpicomm.size(), this->m_k);
36 AcolstWt.zeros(this->m_k, this->
globaln() / this->m_mpicomm.size());
37 ArowsH.zeros(this->
globalm() / this->m_mpicomm.size(), this->m_k);
38 ArowsHt.zeros(this->m_k, this->
globalm() / this->m_mpicomm.size());
// Length-k norm vectors, zero-initialized. By their names, these are
// presumably the per-rank and the combined column norms of W; verify.
39 localWnorm.zeros(this->m_k);
40 Wnorm.zeros(this->m_k);
41 PRINTROOT(
"NAIVEANLSBPP Constructor completed");
// Fragment of the main NMF driver.
// NOTE(review): the fused numbering skips lines (e.g. 63 -> 68, 84 -> 102,
// 140 -> 157). Intervening statements are therefore not visible here,
// including the solver calls whose time is recorded as nnls_duration.
//
// Transpose the column partition of A in place, then view it as Acolst.
// The (memptr, n_rows, n_cols, false, true) constructor presumably aliases
// m_Acols' memory without copying; see the Armadillo advanced constructors.
53 inplace_trans(this->m_Acols);
54 INPUTMATTYPE Acolst(this->m_Acols.memptr(), this->m_Acols.n_rows,
55 this->m_Acols.n_cols,
false,
true);
// NOTE(review): Acolst is declared again on the next line. These two
// definitions were presumably alternate branches of an elided #if/#else
// (original lines 56 and 58 are missing). They cannot both be active.
57 INPUTMATTYPE Acolst = this->m_Acols.t();
60 for (
unsigned int iter = 0; iter < this->
num_iterations(); iter++) {
// Keep the previous H and its Gram matrix from before this iteration.
62 this->m_prevH = this->m_H;
63 this->m_prevHtH = this->HtH;
// --- Update H ---
// Gather the global W. The returned double is recorded as communication
// and allgather time (presumably the elapsed time of the gather).
68 double tempTime = this->
globalW();
69 this->time_stats.communication_duration(tempTime);
70 this->time_stats.allgather_duration(tempTime);
// Gram matrix W^T W from the gathered global W.
72 WtW = this->m_globalW.t() * this->m_globalW;
// NOTE(review): no reassignment of tempTime is visible between here and the
// gather above. The elided line 73 presumably re-measures it; otherwise the
// gather time would be double-counted as compute time.
74 this->time_stats.compute_duration(tempTime);
75 this->time_stats.gram_duration(tempTime);
// Right-hand side for the H subproblem: (A_cols)^T * W.
84 AcolstW = Acolst * this->m_globalW;
102 this->time_stats.compute_duration(tempTime);
103 this->time_stats.mm_duration(tempTime);
// Post-process the solved H^T, then store H as its transpose.
111 fixNumericalError<MAT>(&(this->m_Ht));
116 this->m_H = this->m_Ht.t();
118 this->time_stats.compute_duration(tempTime);
119 this->time_stats.nnls_duration(tempTime);
// --- Update W ---
// NOTE(review): `double tempTime` is declared a second time here. In the
// full source this was presumably in a separate scope; verify.
124 double tempTime = this->
globalH();
125 this->time_stats.communication_duration(tempTime);
126 this->time_stats.allgather_duration(tempTime);
// Gram matrix H^T H from the gathered global H.
128 HtH = this->m_globalH.t() * this->m_globalH;
136 this->time_stats.compute_duration(tempTime);
137 this->time_stats.gram_duration(tempTime);
// Right-hand side for the W subproblem: A_rows * H.
140 ArowsH = this->m_Arows * this->m_globalH;
157 this->time_stats.compute_duration(tempTime);
158 this->time_stats.mm_duration(tempTime);
// Post-process the solved W^T, then store W as its transpose.
164 fixNumericalError<MAT>(&(this->m_Wt));
169 this->m_W = this->m_Wt.t();
171 this->time_stats.compute_duration(tempTime);
172 this->time_stats.nnls_duration(tempTime);
// Accumulate the iteration's wall time.
175 this->time_stats.duration(
mpitoc());
// Per-iteration progress line.
// NOTE(review): relerr divides m_objective_err by m_globalsqnormA, which by
// name is the squared norm of A. Confirm that m_objective_err is also a
// squared residual, so the ratio is the intended relative error.
178 PRINTROOT(
"it=" << iter <<
"::algo::" << this->m_algorithm
179 <<
"::k::" << this->m_k
180 <<
"::err::" << this->m_objective_err <<
"::relerr::" 181 << this->m_objective_err / this->m_globalsqnormA);
// The statement below is the tail of a print whose opening (original line
// 183) is not visible here.
184 <<
"::taken::" << this->time_stats.duration());
// Synchronize ranks before reporting the accumulated timings.
// NOTE(review): this barrier uses MPI_COMM_WORLD rather than the communicator
// passed to the constructor. That is fine if they coincide; otherwise
// confirm. Leading integers ("186", "187", ...) are fused line-number residue
// from extraction.
186 MPI_Barrier(MPI_COMM_WORLD);
// Report each accumulated timing category under its label: total,
// communication, compute, allgather, Gram, matrix multiply, NNLS, and the
// error computation and communication.
187 this->
reportTime(this->time_stats.duration(),
"total_d");
188 this->
reportTime(this->time_stats.communication_duration(),
"total_comm");
189 this->
reportTime(this->time_stats.compute_duration(),
"total_comp");
190 this->
reportTime(this->time_stats.allgather_duration(),
"total_allgather");
191 this->
reportTime(this->time_stats.gram_duration(),
"total_gram");
192 this->
reportTime(this->time_stats.mm_duration(),
"total_mm");
193 this->
reportTime(this->time_stats.nnls_duration(),
"total_nnls");
195 this->
reportTime(this->time_stats.err_compute_duration(),
196 "total_err_compute");
// BUG(review): the "total_err_communication" label below is reported with
// err_compute_duration(), the same value as the "total_err_compute" line
// above. This looks like a copy-paste slip. It should presumably use an
// error-communication timer (e.g. err_communication_duration(), if the
// timing class provides one); confirm against that class and fix.
197 this->
reportTime(this->time_stats.err_compute_duration(),
198 "total_err_communication");
205 #endif // DISTNMF_NAIVE_ANLS_BPP_HPP_ MATTYPE getSolutionMatrix()
void reportTime(const double temp, const std::string &reportstring)
const UWORD globaln() const
const unsigned int num_iterations() const
#define DISTPRINTINFO(MSG)
const bool is_compute_error() const
void computeError(const MAT &WtW, const MAT &HtH)
DistNaiveANLSBPP(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator)
ncp_factors contains the factors of the NCP; the i-th factor is of size n_i * k, and the number of factors is ...
const UWORD globalm() const