planc
Parallel Low-rank Approximation with Non-negativity Constraints
distnmf.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_DISTNMF_HPP_
4 #define DISTNMF_DISTNMF_HPP_
5 
6 #include <string>
7 #include "common/nmf.hpp"
8 #include "distnmf/mpicomm.hpp"
9 #include "distnmftime.hpp"
10 #ifdef USE_PACOSS
11 #include "pacoss.h"
12 #endif
13 
14 namespace planc {
15 template <typename INPUTMATTYPE>
16 class DistNMF : public NMF<INPUTMATTYPE> {
17  protected:
18  const MPICommunicator &m_mpicomm;
19 #ifdef USE_PACOSS
20  Pacoss_Communicator<double> *m_rowcomm;
21  Pacoss_Communicator<double> *m_colcomm;
22 #endif
23  UWORD m_ownedm;
24  UWORD m_ownedn;
25  UWORD m_globalm;
26  UWORD m_globaln;
27  double m_globalsqnormA;
28  DistNMFTime time_stats;
29  uint m_compute_error;
30  algotype m_algorithm;
31  ROWVEC localWnorm;
32  ROWVEC Wnorm;
33 
34  public:
43  DistNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor,
44  const MAT &rightlowrankfactor, const MPICommunicator &communicator)
45  : NMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor),
46  m_mpicomm(communicator),
47  time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
48  double sqnorma = this->normA * this->normA;
49  this->m_globalm = 0;
50  this->m_globaln = 0;
51  MPI_Allreduce(&sqnorma, &(this->m_globalsqnormA), 1, MPI_DOUBLE, MPI_SUM,
52  MPI_COMM_WORLD);
53  this->m_ownedm = this->W.n_rows;
54  this->m_ownedn = this->H.n_rows;
55 #ifdef USE_PACOSS
56  // TODO(kayaogz): This is a hack for now. Talk to Ramki.
57  this->m_globalm = this->W.n_rows * this->m_mpicomm.size();
58  this->m_globaln = this->H.n_rows * this->m_mpicomm.size();
59 #else
60  MPI_Allreduce(&(this->m), &(this->m_globalm), 1, MPI_INT, MPI_SUM,
61  this->m_mpicomm.commSubs()[0]);
62  MPI_Allreduce(&(this->n), &(this->m_globaln), 1, MPI_INT, MPI_SUM,
63  this->m_mpicomm.commSubs()[1]);
64 #endif
65  if (ISROOT) {
66  INFO << "globalsqnorma::" << this->m_globalsqnormA
67  << "::globalm::" << this->m_globalm
68  << "::globaln::" << this->m_globaln << std::endl;
69  }
70  this->m_compute_error = 0;
71  localWnorm.zeros(this->k);
72  Wnorm.zeros(this->k);
73  }
74 
75 #ifdef USE_PACOSS
76  void set_rowcomm(Pacoss_Communicator<double> *rowcomm) {
77  this->m_rowcomm = rowcomm;
78  }
79  void set_colcomm(Pacoss_Communicator<double> *colcomm) {
80  this->m_colcomm = colcomm;
81  }
82 #endif
83  const int globalm() const { return m_globalm; }
86  const int globaln() const { return m_globaln; }
88  const double globalsqnorma() const { return m_globalsqnormA; }
90  void compute_error(const uint &ce) { this->m_compute_error = ce; }
92  const bool is_compute_error() const { return (this->m_compute_error); }
94  void algorithm(algotype dat) { this->m_algorithm = dat; }
96  void reportTime(const double temp, const std::string &reportstring) {
97  double mintemp, maxtemp, sumtemp;
98  MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
99  MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
100  MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
101  PRINTROOT(reportstring << "::m::" << this->m_globalm
102  << "::n::" << this->m_globaln << "::k::" << this->k
103  << "::SIZE::" << MPI_SIZE
104  << "::algo::" << this->m_algorithm
105  << "::root::" << temp << "::min::" << mintemp
106  << "::avg::" << (sumtemp) / (MPI_SIZE)
107  << "::max::" << maxtemp);
108  }
110  void normalize_by_W() {
111  localWnorm = sum(this->W % this->W);
112  mpitic();
113  MPI_Allreduce(localWnorm.memptr(), Wnorm.memptr(), this->k, MPI_DOUBLE,
114  MPI_SUM, MPI_COMM_WORLD);
115  double temp = mpitoc();
116  this->time_stats.allgather_duration(temp);
117  for (int i = 0; i < this->k; i++) {
118  if (Wnorm(i) > 1) {
119  double norm_const = sqrt(Wnorm(i));
120  this->W.col(i) = this->W.col(i) / norm_const;
121  this->H.col(i) = norm_const * this->H.col(i);
122  }
123  }
124  }
125 };
126 
127 } // namespace planc
128 
129 #endif // DISTNMF_DISTNMF_HPP_
const int globaln() const
returns globaln
Definition: distnmf.hpp:86
void normalize_by_W()
Column Normalizes the distributed W matrix.
Definition: distnmf.hpp:110
double mpitoc(int rank)
Definition: distutils.hpp:22
const bool is_compute_error() const
returns the flag to compute error or not.
Definition: distnmf.hpp:92
void compute_error(const uint &ce)
return the current error
Definition: distnmf.hpp:90
algotype
Definition: utils.h:10
const int globalm() const
returns globalm
Definition: distnmf.hpp:84
const double allgather_duration() const
Definition: distnmftime.hpp:65
void reportTime(const double temp, const std::string &reportstring)
Reports the time.
Definition: distnmf.hpp:96
void algorithm(algotype dat)
returns the NMF algorithm
Definition: distnmf.hpp:94
void mpitic()
Definition: distutils.hpp:11
#define ISROOT
Definition: distutils.h:14
#define INFO
Definition: utils.h:36
#define UWORD
Definition: utils.h:60
DistNMF(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator)
There are totally prxpc process.
Definition: distnmf.hpp:43
const int size() const
returns the total number of mpi processes
Definition: mpicomm.hpp:120
#define MAT
Definition: utils.h:52
const double globalsqnorma() const
returns global squared norm of A
Definition: distnmf.hpp:88
#define ROWVEC
Definition: utils.h:54
ncp_factors contains the factors of the ncp every ith factor is of size n_i * k number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32
#define MPI_SIZE
Definition: distutils.h:15
const MPI_Comm * commSubs() const
Definition: mpicomm.hpp:129