// Distributed NMF class template. It appears each rank holds a row block
// (Arows) and a column block (Acols) of the input, inferred from the
// constructor; confirm.
// NOTE(review): this copy of the header is fragmentary (original line numbers
// are embedded in the text and many lines are missing); comments below describe
// only what is visible here.
3 #ifndef DISTNMF_DISTNMF1D_HPP_ 4 #define DISTNMF_DISTNMF1D_HPP_ 13 template <
class INPUTMATTYPE>
// Global matrix dimensions (rows, columns), set in the constructor.
19 UWORD m_globalm, m_globaln;
// Gathered global factors and their transposes. They are zero-allocated in the
// constructor and filled by MPI_Allgather of the local factors.
22 MAT m_globalW, m_globalH;
23 MAT m_globalWt, m_globalHt;
// Most recent objective value written by computeError().
24 double m_objective_err;
// Squared Frobenius norm of the input, summed across all ranks in the constructor.
25 double m_globalsqnormA;
// Iteration count; the constructor initializes it to 20.
26 unsigned int m_num_iterations;
// Constructor. Fragment: the trailing `const MPICommunicator &mpicomm`
// parameter, the start of the member-initializer list, and the closing brace
// are missing from this copy.
// Stores the initial factors, derives global dimensions and rank k, allocates
// zeroed global buffers, and computes the global squared Frobenius norm of A.
40 DistNMF1D(
const INPUTMATTYPE &Arows,
const INPUTMATTYPE &Acols,
41 const MAT &leftlowrankfactor,
const MAT &rightlowrankfactor,
46 m_W(leftlowrankfactor),
47 m_H(rightlowrankfactor),
48 time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
// NOTE(review): the global row count assumes every rank holds exactly the same
// number of rows of A. With an uneven row distribution m_globalm (and the
// buffers sized from it) would be wrong.
49 this->m_globalm = Arows.n_rows *
MPI_SIZE;
50 this->m_globaln = Arows.n_cols;
51 this->m_Wt = this->m_W.t();
52 this->m_Ht = this->m_H.t();
// The low rank k is taken from the column count of the left factor.
53 this->m_k = this->m_W.n_cols;
// Default number of iterations.
54 this->m_num_iterations = 20;
// Zero-initialized global buffers: W is globalm x k, H is globaln x k, and
// their transposes; HAtW/globalHAtW are k x k workspaces used by computeError().
55 m_globalW.zeros(this->m_globalm, this->m_k);
56 err_matrix.zeros(this->m_globalm, this->m_k);
57 m_globalWt.zeros(this->m_k, this->m_globalm);
58 m_globalH.zeros(this->m_globaln, this->m_k);
59 m_globalHt.zeros(this->m_k, this->m_globaln);
60 HAtW.zeros(this->m_k, this->m_k);
61 globalHAtW.zeros(this->m_k, this->m_k);
// Each rank squares the Frobenius norm of its own row block, and the sum over
// all ranks gives ||A||_F^2 (assuming the row blocks partition A; confirm).
// The value is cached so computeError() need not touch A for this term.
62 double localsqnormA = norm(Arows,
"fro");
63 localsqnormA = localsqnormA * localsqnormA;
64 MPI_Allreduce(&localsqnormA, &(this->m_globalsqnormA), 1, MPI_DOUBLE,
65 MPI_SUM, MPI_COMM_WORLD);
// Diagnostic print of the global dimensions (the head of this statement,
// original line 66, is missing from this copy).
67 <<
"::globalm::" << this->m_globalm
68 <<
"::globaln::" << this->m_globaln);
// Gathers the distributed left factor W onto every rank. This is the body of a
// routine whose signature, communicator argument and several lines are not
// visible here.
// Each rank contributes the transpose of its local W. The gathered transposes
// are concatenated into m_globalWt and transposed back into m_globalW, which
// stacks the local row blocks of W in rank order.
// NOTE(review): that stacking relies on column-major storage behind memptr();
// confirm for MAT.
// NOTE(review): identical send/recv counts mean every rank must hold a local W
// of the same size (MPI_Allgather has no per-rank counts). The int counts also
// cap a local block at INT_MAX elements.
// commTime is presumably closed against an mpitic() on a missing line.
74 int sendcnt = this->m_W.n_rows * this->m_W.n_cols;
75 int recvcnt = this->m_W.n_rows * this->m_W.n_cols;
76 this->m_Wt = this->m_W.t();
78 MPI_Allgather(this->m_Wt.memptr(), sendcnt, MPI_DOUBLE,
79 this->m_globalWt.memptr(), recvcnt, MPI_DOUBLE,
87 double commTime =
mpitoc();
89 this->m_globalW = this->m_globalWt.t();
// Gathers the distributed right factor H onto every rank. This is the body of a
// routine whose signature, communicator argument and several lines are not
// visible here.
// It mirrors the W gather: local H^T blocks are concatenated into m_globalHt,
// which is transposed back into m_globalH (globaln x k).
// NOTE(review): this assumes equal local H sizes on every rank (equal-count
// MPI_Allgather) and column-major memptr() layout. The int counts cap a local
// block at INT_MAX elements.
// commTime is presumably closed against an mpitic() on a missing line.
96 int sendcnt = this->m_H.n_rows * this->m_H.n_cols;
97 int recvcnt = this->m_H.n_rows * this->m_H.n_cols;
98 this->m_Ht = this->m_H.t();
100 MPI_Allgather(this->m_Ht.memptr(), sendcnt, MPI_DOUBLE,
101 this->m_globalHt.memptr(), recvcnt, MPI_DOUBLE,
109 double commTime =
mpitoc();
110 this->m_globalH = this->m_globalHt.t();
// Body of computeError(WtW, HtH). Fragment: the signature, the else/closing
// braces and other lines are missing from this copy.
// Evaluates the objective through the trace expansion
//   ||A - W H^T||_F^2 = ||A||_F^2 - 2 tr(H^T A^T W) + tr(W^T W H^T H),
// so W*H^T is never formed explicitly.
// This assumes the caller passes WtW = W^T W and HtH = H^T H; confirm at call
// sites.
// Each rank forms a k x k partial H^T A_j^T W from its column block Acols. The
// two branches appear to handle Acols stored with globalm rows (transposed
// here) or already transposed.
// NOTE(review): m_prevH (not m_H) is used, presumably the H that pairs with the
// current W; confirm.
124 if (this->m_Acols.n_rows == this->m_globalm) {
125 HAtW = this->m_prevH.t() * (this->m_Acols.t() * this->m_globalW);
129 HAtW = this->m_prevH.t() * (this->m_Acols * this->m_globalW);
// Sum the per-rank k x k partial products so every rank sees the global
// H^T A^T W.
134 MPI_Allreduce(HAtW.memptr(), globalHAtW.memptr(), this->m_k * this->m_k,
135 MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
// Cross term and quadratic term of the expansion above.
139 double tHAtW = trace(globalHAtW);
140 double tWtWHtH = trace(WtW * HtH);
141 PRINTROOT(
"normA::" << this->m_globalsqnormA <<
"::tHAtW::" << 2 * tHAtW
142 <<
"::tWtWHtH::" << tWtWHtH);
143 this->m_objective_err = this->m_globalsqnormA - 2 * tHAtW + tWtWHtH;
// Reports a per-rank timing value across the whole job. Fragment: original
// line 182 and the closing brace are missing from this copy.
// @param temp          this rank's measured value
// @param reportstring  label printed at the start of the report line
// The value is reduced with MPI_MAX, MPI_MIN and MPI_SUM over MPI_COMM_WORLD.
// PRINTROOT then prints the label with the global m, n, k and algorithm, this
// rank's own value (labelled "root"), and the min, average (sum / MPI_SIZE) and
// max across ranks.
// Collective: it issues three MPI_Allreduce calls, so every rank in
// MPI_COMM_WORLD must call it.
175 void reportTime(
const double temp,
const std::string &reportstring) {
176 double mintemp, maxtemp, sumtemp;
177 MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
178 MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
179 MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
180 PRINTROOT(reportstring <<
"::m::" << this->m_globalm
181 <<
"::n::" << this->m_globaln <<
"::k::" << this->m_k
183 <<
"::algo::" << this->m_algorithm
184 <<
"::root::" << temp <<
"::min::" << mintemp
185 <<
"::avg::" << (sumtemp) / (
MPI_SIZE)
186 <<
"::max::" << maxtemp);
192 #endif // DISTNMF_DISTNMF1D_HPP_ const double err_communication_duration() const
virtual void computeNMF()=0
const double err_compute_duration() const
void algorithm(algotype dat)
MAT getRightLowRankFactor()
void reportTime(const double temp, const std::string &reportstring)
const UWORD globaln() const
const unsigned int num_iterations() const
#define DISTPRINTINFO(MSG)
const bool is_compute_error() const
void computeError(const MAT &WtW, const MAT &HtH)
MAT getLeftLowRankFactor()
void num_iterations(int it)
void compute_error(const uint &ce)
ncp_factors contains the factors of the ncp every ith factor is of size n_i * k number of factors is ...
const UWORD globalm() const
DistNMF1D(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &mpicomm)