#ifndef DISTNMF_DISTNMF_HPP_
#define DISTNMF_DISTNMF_HPP_

template <typename INPUTMATTYPE>
class DistNMF : public NMF<INPUTMATTYPE> {
 protected:
  // Row and column communicators used with the Pacoss partitioner.
  Pacoss_Communicator<double> *m_rowcomm;
  Pacoss_Communicator<double> *m_colcomm;
  // Global squared norm of the input matrix A.
  double m_globalsqnormA;
 public:
  DistNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor,
          const MAT &rightlowrankfactor, const MPICommunicator &communicator)
      : NMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor),
        m_mpicomm(communicator),
        time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
    double sqnorma = this->normA * this->normA;
    MPI_Allreduce(&sqnorma, &(this->m_globalsqnormA), 1, MPI_DOUBLE, MPI_SUM,
                  MPI_COMM_WORLD);
    this->m_ownedm = this->W.n_rows;
    this->m_ownedn = this->H.n_rows;
    // Global dimensions assuming equally sized local factor blocks.
    this->m_globalm = this->W.n_rows * this->m_mpicomm.size();
    this->m_globaln = this->H.n_rows * this->m_mpicomm.size();
    // Sum the locally owned dimensions m and n across all processes.
    MPI_Allreduce(&(this->m), &(this->m_globalm), 1, MPI_INT, MPI_SUM,
                  MPI_COMM_WORLD);
    MPI_Allreduce(&(this->n), &(this->m_globaln), 1, MPI_INT, MPI_SUM,
                  MPI_COMM_WORLD);
    INFO << "globalsqnorma::" << this->m_globalsqnormA
         << "::globalm::" << this->m_globalm
         << "::globaln::" << this->m_globaln << std::endl;
    this->m_compute_error = 0;
    localWnorm.zeros(this->k);
  }
  void set_rowcomm(Pacoss_Communicator<double> *rowcomm) {
    this->m_rowcomm = rowcomm;
  }
  void set_colcomm(Pacoss_Communicator<double> *colcomm) {
    this->m_colcomm = colcomm;
  }
  const int globalm() const { return m_globalm; }
  const int globaln() const { return m_globaln; }
  void reportTime(const double temp, const std::string &reportstring) {
    double mintemp, maxtemp, sumtemp;
    MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
    MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
    MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    PRINTROOT(reportstring << "::m::" << this->m_globalm
              << "::n::" << this->m_globaln << "::k::" << this->k
              << "::algo::" << this->m_algorithm
              << "::root::" << temp << "::min::" << mintemp
              << "::avg::" << (sumtemp) / (MPI_SIZE)
              << "::max::" << maxtemp);
  }
  /// Column normalizes the distributed W matrix.
  void normalize_by_W() {
    // Squared norms of the local columns of W.
    localWnorm = sum(this->W % this->W);
    MPI_Allreduce(localWnorm.memptr(), Wnorm.memptr(), this->k, MPI_DOUBLE,
                  MPI_SUM, MPI_COMM_WORLD);
    for (int i = 0; i < this->k; i++) {
      double norm_const = sqrt(Wnorm(i));
      // Scale each column of W to unit norm and let H absorb the factor.
      this->W.col(i) = this->W.col(i) / norm_const;
      this->H.col(i) = norm_const * this->H.col(i);
    }
  }
};  // class DistNMF

#endif  // DISTNMF_DISTNMF_HPP_

Member summary:

  DistNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator)
      Constructor. There are a total of pr x pc processes (a grid sketch follows this list).
  const int globalm() const
      Returns globalm.
  const int globaln() const
      Returns globaln.
  const double globalsqnorma() const
      Returns the global squared norm of A.
  const int size() const
      Returns the total number of MPI processes.
  const bool is_compute_error() const
      Returns the flag that controls whether the error is computed.
  void compute_error(const uint &ce)
      Sets the flag that controls whether the error is computed.
  void algorithm(algotype dat)
      Sets the NMF algorithm.
  const double allgather_duration() const
  void reportTime(const double temp, const std::string &reportstring)
      Reports the time (a timing sketch follows this list).
  void normalize_by_W()
      Column normalizes the distributed W matrix (a normalization sketch follows this list).
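The constructor brief above notes that there are pr x pc processes in total. The MPICommunicator internals are not part of this listing, so the following is only a minimal plain-MPI sketch of how such a two-dimensional process grid can be formed; the names pr, pc, gridcomm and the use of MPI_Cart_create are illustrative assumptions, not the planc implementation.

#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, nprocs;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nprocs);

  // Split the nprocs ranks into a pr x pc grid (hypothetical layout).
  int dims[2] = {0, 0};
  MPI_Dims_create(nprocs, 2, dims);   // dims[0] = pr, dims[1] = pc
  int periods[2] = {0, 0};
  MPI_Comm gridcomm;
  MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &gridcomm);

  int coords[2];
  MPI_Cart_coords(gridcomm, rank, 2, coords);
  std::printf("rank %d -> grid position (%d, %d) in a %d x %d grid\n",
              rank, coords[0], coords[1], dims[0], dims[1]);

  MPI_Comm_free(&gridcomm);
  MPI_Finalize();
  return 0;
}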
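reportTime() aggregates one local timing value into its minimum, average, and maximum across all ranks. The sketch below reproduces that reduction pattern with plain MPI; PRINTROOT and MPI_SIZE are macros of the surrounding code base and are replaced here by an explicit rank-0 check and MPI_Comm_size, and the timed work is a placeholder.

#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  double t0 = MPI_Wtime();
  // ... local work being timed ...
  double temp = MPI_Wtime() - t0;

  // Same reduction pattern as reportTime(): min, max, and sum of the
  // per-rank timings; the average is the sum divided by the rank count.
  double mintemp, maxtemp, sumtemp;
  MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
  MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
  MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
  if (rank == 0) {
    std::printf("report::root::%g::min::%g::avg::%g::max::%g\n",
                temp, mintemp, sumtemp / size, maxtemp);
  }
  MPI_Finalize();
  return 0;
}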
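normalize_by_W() column normalizes the distributed W and rescales H so that the product W * H^T is unchanged. The following standalone sketch shows the same pattern, with arma::mat standing in for the code base's MAT typedef and arbitrary illustration values for the local block sizes (100 and 80 rows) and k = 4.

#include <mpi.h>
#include <armadillo>
#include <cmath>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  const int k = 4;                                 // low rank
  arma::mat Wlocal(100, k, arma::fill::randu);     // local row block of W
  arma::mat Hlocal(80, k, arma::fill::randu);      // local row block of H

  // Squared column norms of the local W block, then summed across ranks
  // to obtain the squared column norms of the global W.
  arma::rowvec localWnorm = arma::sum(Wlocal % Wlocal);
  arma::rowvec Wnorm(k);
  MPI_Allreduce(localWnorm.memptr(), Wnorm.memptr(), k, MPI_DOUBLE, MPI_SUM,
                MPI_COMM_WORLD);

  for (int i = 0; i < k; i++) {
    double norm_const = std::sqrt(Wnorm(i));
    Wlocal.col(i) /= norm_const;   // column of the global W becomes unit norm
    Hlocal.col(i) *= norm_const;   // H absorbs the factor, W * H^T unchanged
  }

  MPI_Finalize();
  return 0;
}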