#ifndef DISTNMF_AUNMF_HPP_
#define DISTNMF_AUNMF_HPP_

// ...

template <class INPUTMATTYPE>
class DistAUNMF : public DistNMF<INPUTMATTYPE> {
  virtual void updateW() = 0;
  virtual void updateH() = 0;
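  // Per-process receive counts handed to MPI_Reduce_scatter in the W^T*A and
  // A*H routines below (distWtA() / distAH()).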
  std::vector<int> recvWtAsize;
  std::vector<int> recvAHsize;
  void allocateMatrices() {
    DISTPRINTINFO(/* ... */
                  << "::localm::" << this->m << "::localn::" << this->n
                  << "::globalm::" << this->globalm()
                  << "::globaln::" << this->globaln());
    HtH.zeros(this->k, this->k);
    localHtH.zeros(this->k, this->k);
    Hj.zeros(this->n, this->perk);
    Hjt.zeros(this->perk, this->n);
    AijHj.zeros(this->m, this->perk);
    AijHjt.zeros(this->perk, this->m);
    // ...
    fillVector<int>(fillsize, &recvAHsize);
    // ...
    INFO << "::recvAHsize::";
    printVector<int>(recvAHsize);
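    // Analogous W-side buffers: WtW/localWtW hold the k x k gram of W, Wit the
    // gathered perk-row block of W^T, and WitAij its product with the local
    // block of A.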
    WtW.zeros(this->k, this->k);
    localWtW.zeros(this->k, this->k);
    Wi.zeros(this->m, this->perk);
    Wit.zeros(this->perk, this->m);
    WitAij.zeros(this->perk, this->n);
    AijWit.zeros(this->n, this->perk);
    // ...
    fillVector<int>(fillsize, &recvWtAsize);
    // ...
    INFO << "::recvWtAsize::";
    printVector<int>(recvWtAsize);
    // ...
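    // Dense residual buffers used only by computeError2(); each holds the local
    // m x n block of W*H^T or of A - W*H^T, so they can be large.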
    errMtx.zeros(this->m, this->n);
    A_errMtx.zeros(this->m, this->n);
  }
  void freeMatrices() {
    // ...
  }
  DistAUNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor,
            const MAT &rightlowrankfactor, const MPICommunicator &communicator,
            const int numkblks)
      : DistNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
                              communicator) {
    num_k_blocks = numkblks;
    perk = this->k / num_k_blocks;
    // ...
    this->Wt = leftlowrankfactor.t();
    this->Ht = rightlowrankfactor.t();
    PRINTROOT("aunmf()::constructor successful");
  }
  void distWtA() {
    for (int i = 0; i < num_k_blocks; i++) {
      int start_row = i * perk;
      int end_row = (i + 1) * perk - 1;
      Wt_blk = Wt.rows(start_row, end_row);
      // ...
      WtAij.rows(start_row, end_row) = WtAij_blk;
    }
  }
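  // Per-block computation of W^T A: gather the current perk-row block of W^T
  // along the processor row so that Wit spans all owners of the local A block,
  // multiply by the local A, then reduce-scatter the partial products down the
  // processor column so each process keeps only the rows of W^T A it owns.
  // The Pacoss expComm/foldComm calls and the plain MPI_Allgather /
  // MPI_Reduce_scatter calls are presumably alternative compile-time paths;
  // the selecting preprocessor lines are not part of this fragment. The `temp`
  // values fed to time_stats presumably come from elided MPITOC timer reads.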
  // ...
    memcpy(Wit.memptr(), Wt_blk.memptr(),
           Wt_blk.n_rows * Wt_blk.n_cols * sizeof(Wt_blk[0]));
    this->m_rowcomm->expCommBegin(Wit.memptr(), this->perk);
    this->m_rowcomm->expCommFinish(Wit.memptr(), this->perk);
    // ...
    MPI_Allgather(Wt_blk.memptr(), sendcnt, MPI_DOUBLE, Wit.memptr(), recvcnt,
                  MPI_DOUBLE, this->m_mpicomm.commSubs()[1]);
    // ...
    this->time_stats.communication_duration(temp);
    this->time_stats.allgather_duration(temp);
    this->WitAij = this->Wit * this->A;
    // ...
    this->time_stats.compute_duration(temp);
    this->time_stats.mm_duration(temp);
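    // Sum the partial W_i^T A_ij across the processor column and scatter the
    // result so that each process ends up with its recvWtAsize-sized share of
    // the rows of W^T A (the fold / reduce-scatter step).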
    this->m_colcomm->foldCommBegin(WitAij.memptr(), this->perk);
    this->m_colcomm->foldCommFinish(WitAij.memptr(), this->perk);
    // ...
    memcpy(WtAij_blk.memptr(), WitAij.memptr(),
           WtAij_blk.n_rows * WtAij_blk.n_cols * sizeof(WtAij_blk[0]));
    // ...
    MPI_Reduce_scatter(this->WitAij.memptr(), this->WtAij_blk.memptr(),
                       &(this->recvWtAsize[0]), MPI_DOUBLE, MPI_SUM,
                       this->m_mpicomm.commSubs()[0]);
    // ...
    this->time_stats.communication_duration(temp);
    this->time_stats.reducescatter_duration(temp);
  void distAH() {
    for (int i = 0; i < num_k_blocks; i++) {
      int start_row = i * perk;
      int end_row = (i + 1) * perk - 1;
      Ht_blk = Ht.rows(start_row, end_row);
      // ...
      AHtij.rows(start_row, end_row) = AHtij_blk;
    }
  }
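  // Per-block computation of A H (stored transposed as AHtij): the mirror image
  // of the W^T A routine above with the grid dimensions swapped. Hjt is gathered
  // along the processor column, multiplied by the transposed local block
  // A_ij^T, and the partial products are reduce-scattered along the processor
  // row using recvAHsize.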
  // ...
    memcpy(Hjt.memptr(), Ht_blk.memptr(),
           Ht_blk.n_rows * Ht_blk.n_cols * sizeof(Ht_blk[0]));
    this->m_colcomm->expCommBegin(Hjt.memptr(), this->perk);
    this->m_colcomm->expCommFinish(Hjt.memptr(), this->perk);
    // ...
    MPI_Allgather(this->Ht_blk.memptr(), sendcnt, MPI_DOUBLE,
                  this->Hjt.memptr(), recvcnt, MPI_DOUBLE,
                  this->m_mpicomm.commSubs()[0]);
    // ...
    this->time_stats.communication_duration(temp);
    this->time_stats.allgather_duration(temp);
    this->AijHjt = this->Hjt * this->A_ij_t;
    // ...
    this->time_stats.compute_duration(temp);
    this->time_stats.mm_duration(temp);
    // ...
    this->m_rowcomm->foldCommBegin(AijHjt.memptr(), this->perk);
    this->m_rowcomm->foldCommFinish(AijHjt.memptr(), this->perk);
    // ...
    memcpy(AHtij_blk.memptr(), AijHjt.memptr(),
           AHtij_blk.n_rows * AHtij_blk.n_cols * sizeof(AHtij_blk[0]));
    // ...
    MPI_Reduce_scatter(this->AijHjt.memptr(), this->AHtij_blk.memptr(),
                       &(this->recvAHsize[0]), MPI_DOUBLE, MPI_SUM,
                       this->m_mpicomm.commSubs()[1]);
    // ...
    this->time_stats.communication_duration(temp);
    this->time_stats.reducescatter_duration(temp);
  void distInnerProduct(const MAT &X, MAT *XtX) {
    // ...
    localWtW = X.t() * X;
    DISTPRINTINFO(/* ... */
                  << "::localWtW::" << norm(this->localWtW, "fro"));
    // ...
    this->time_stats.compute_duration(temp);
    this->time_stats.gram_duration(temp);
    if (X.n_rows == this->m) {
      // ...
    }
    MPI_Allreduce(localWtW.memptr(), (*XtX).memptr(), this->k * this->k,
                  MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    // ...
    this->time_stats.communication_duration(temp);
    this->time_stats.allreduce_duration(temp);
  }
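  // The gram matrices WtW and HtH used in computeNMF() below are obtained this
  // way: every process forms X^T X on its local rows, and a single
  // MPI_Allreduce sums the k x k partial grams so that all p processes hold
  // the same global gram matrix.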
  void computeNMF() {
    prevH.zeros(size(this->H));
    prevHtH.zeros(this->k, this->k);
    WtAijH.zeros(this->k, this->k);
    localWtAijH.zeros(this->k, this->k);
    // ...
#ifdef __WITH__BARRIER__TIMING__
    MPI_Barrier(MPI_COMM_WORLD);
#endif
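    // Main alternating loop (Algorithm 1 of the HPC-NMF paper): each iteration
    // saves the previous H and HtH for the error computation, applies the
    // H-regularizer to WtW before the H solve and the W-regularizer to HtH
    // before the W solve, and charges the solve time to the nnls bucket.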
    for (unsigned int iter = 0; iter < this->num_iterations(); iter++) {
      // ...
      this->prevH = this->H;
      this->prevHtH = this->HtH;
      // ...
      this->applyReg(this->regH(), &this->WtW);
      // ...
      this->time_stats.compute_duration(temp);
      this->time_stats.nnls_duration(temp);
      // ...
      this->applyReg(this->regW(), &this->HtH);
      // ...
      this->time_stats.compute_duration(temp);
      this->time_stats.nnls_duration(temp);
      // ...
      this->time_stats.duration(MPITOC);
      // ...
      PRINTROOT("it=" << iter << "::algo::" << this->m_algorithm << "::k::"
                << this->k << "::err::" << sqrt(this->objective_err)
                // ...
                << sqrt(this->objective_err / this->m_globalsqnormA));
      // ...
      PRINTROOT(/* ... */
                << "::taken::" << this->time_stats.duration());
    }
    MPI_Barrier(MPI_COMM_WORLD);
    this->reportTime(this->time_stats.duration(), "total_d");
    this->reportTime(this->time_stats.communication_duration(), "total_comm");
    this->reportTime(this->time_stats.compute_duration(), "total_comp");
    this->reportTime(this->time_stats.allgather_duration(), "total_allgather");
    this->reportTime(this->time_stats.allreduce_duration(), "total_allreduce");
    this->reportTime(this->time_stats.reducescatter_duration(),
                     "total_reducescatter");
    this->reportTime(this->time_stats.gram_duration(), "total_gram");
    this->reportTime(this->time_stats.mm_duration(), "total_mm");
    this->reportTime(this->time_stats.nnls_duration(), "total_nnls");
    // ...
    this->reportTime(this->time_stats.err_compute_duration(),
                     "total_err_compute");
    this->reportTime(this->time_stats.err_communication_duration(),
                     "total_err_communication");
  }
  void computeError(const int it) {
    // ...
    this->localWtAijH = this->WtAij * this->prevH;
    // ...
    this->time_stats.err_compute_duration(temp);
    // ...
    MPI_Allreduce(this->localWtAijH.memptr(), this->WtAijH.memptr(),
                  this->k * this->k, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    // ...
    this->time_stats.err_communication_duration(temp);
    double tWtAijh = trace(this->WtAijH);
    double tWtWHtH = trace(this->WtW * this->prevHtH);
    PRINTROOT("::it=" << it << "normA::" << this->m_globalsqnormA
              << "::tWtAijH::" << 2 * tWtAijh
              << "::tWtWHtH::" << tWtWHtH);
    this->objective_err = this->m_globalsqnormA - 2 * tWtAijh + tWtWHtH;
  }
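  // computeError() uses the expansion
  //   ||A - W H^T||_F^2 = ||A||_F^2 - 2 trace(W^T A H) + trace(W^T W H^T H),
  // so the error can be reported from the already-available k x k matrices
  // (WtAij * prevH and WtW * prevHtH) without ever forming W H^T.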
  void computeError2(const int it) {
    // ...
    double local_sqerror = 0.0;
    PRINTROOT("::it=" << it << "::Calling compute error 2");
    // ...
    this->Wi = this->Wit.t();
    errMtx = this->Wi * this->Hjt;
    A_errMtx = this->A - errMtx;
    local_sqerror = norm(A_errMtx, "fro");
    local_sqerror *= local_sqerror;
    // ...
    this->time_stats.err_compute_duration(temp);
    // ...
    MPI_Allreduce(&local_sqerror, &this->objective_err, 1, MPI_DOUBLE, MPI_SUM,
                  /* ... */);
    // ...
    this->time_stats.err_communication_duration(temp);
  }
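  // computeError2() instead forms the local residual block explicitly
  // (errMtx = Wi * Hjt, A_errMtx = A - errMtx) and allreduces the squared
  // Frobenius norms; it is simpler but needs the dense m x n buffers allocated
  // in allocateMatrices().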
};  // class DistAUNMF

#endif  // DISTNMF_AUNMF_HPP_
Referenced members:

const int globaln() const
    Returns globaln.
void computeError(const int it)
    We assume this error function will be called in every iteration, before updating the block, to compute ...
const unsigned int num_iterations() const
    Returns the number of iterations.
const bool is_compute_error() const
    Returns the flag that controls whether the error is computed.
void computeError2(const int it)
const int globalm() const
    Returns globalm.
void reportTime(const double temp, const std::string &reportstring)
    Reports the time.
#define DISTPRINTINFO(MSG)
void distWtA()
    This is a matrix multiplication routine based on reduce_scatter.
void computeNMF()
    This is the main loop function. See Algorithm 1 on page 3 of the PPoPP HPC-NMF paper.
FVEC regH()
    Returns the L2 and L1 regularization parameters of H as a vector.
FVEC regW()
    Returns the L2 and L1 regularization parameters of W as a vector.
void distAH()
    There are a total of pr x pc processes.
DistAUNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
    Public constructor with the local input matrix, local factors, and communicator.
    ncp_factors contains the factors of the NCP; every ith factor is of size n_i x k; the number of factors is ...
void distInnerProduct(const MAT &X, MAT *XtX)
    There are p processes.
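A minimal usage sketch, assuming a hypothetical concrete subclass (called DistMUANLS here) that keeps the DistAUNMF constructor shape and implements updateW()/updateH(); MAT, MPICommunicator, and computeNMF() come from the listing above, everything else is illustrative:

template <class INPUTMATTYPE>
void run_dist_aunmf(const INPUTMATTYPE &A_local, const MAT &W_local,
                    const MAT &H_local, const MPICommunicator &mpicomm,
                    int num_k_blocks) {
  // DistMUANLS is a stand-in name for any concrete subclass that implements
  // the pure virtual updateW()/updateH() hooks declared above.
  DistMUANLS<INPUTMATTYPE> nmf(A_local, W_local, H_local, mpicomm, num_k_blocks);
  nmf.computeNMF();  // alternating updates with distributed WtA, AH, and gram steps
}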