planc
Parallel Lowrank Approximation with Non-negativity Constraints
distnmf1D.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_DISTNMF1D_HPP_
4 #define DISTNMF_DISTNMF1D_HPP_
5 
6 #include <string>
7 #include "common/distutils.hpp"
8 #include "common/utils.h"
9 #include "common/utils.hpp"
10 
11 namespace planc {
12 
13 template <class INPUTMATTYPE>
14 class DistNMF1D {
15  protected:
16  const MPICommunicator &m_mpicomm;
17  INPUTMATTYPE m_Arows;
18  INPUTMATTYPE m_Acols;
19  UWORD m_globalm, m_globaln;
20  MAT m_W, m_H;
21  MAT m_Wt, m_Ht;
22  MAT m_globalW, m_globalH;
23  MAT m_globalWt, m_globalHt;
24  double m_objective_err;
25  double m_globalsqnormA;
26  unsigned int m_num_iterations;
27  unsigned int m_k; // low rank k
28  DistNMFTime time_stats;
29  MAT m_prevH; // this is needed for error computation
30  MAT m_prevHtH; // this is needed for error computation
31  uint m_compute_error;
32  algotype m_algorithm;
33 
34  private:
35  MAT HAtW; // needed for error computation
36  MAT globalHAtW; // needed for error computation
37  MAT err_matrix; // needed for error computation.
38 
39  public:
40  DistNMF1D(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols,
41  const MAT &leftlowrankfactor, const MAT &rightlowrankfactor,
42  const MPICommunicator &mpicomm)
43  : m_mpicomm(mpicomm),
44  m_Arows(Arows),
45  m_Acols(Acols),
46  m_W(leftlowrankfactor),
47  m_H(rightlowrankfactor),
48  time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
49  this->m_globalm = Arows.n_rows * MPI_SIZE;
50  this->m_globaln = Arows.n_cols;
51  this->m_Wt = this->m_W.t();
52  this->m_Ht = this->m_H.t();
53  this->m_k = this->m_W.n_cols;
54  this->m_num_iterations = 20;
55  m_globalW.zeros(this->m_globalm, this->m_k);
56  err_matrix.zeros(this->m_globalm, this->m_k);
57  m_globalWt.zeros(this->m_k, this->m_globalm);
58  m_globalH.zeros(this->m_globaln, this->m_k);
59  m_globalHt.zeros(this->m_k, this->m_globaln);
60  HAtW.zeros(this->m_k, this->m_k);
61  globalHAtW.zeros(this->m_k, this->m_k);
62  double localsqnormA = norm(Arows, "fro");
63  localsqnormA = localsqnormA * localsqnormA;
64  MPI_Allreduce(&localsqnormA, &(this->m_globalsqnormA), 1, MPI_DOUBLE,
65  MPI_SUM, MPI_COMM_WORLD);
66  DISTPRINTINFO("DistNMF1D constructor completed"
67  << "::globalm::" << this->m_globalm
68  << "::globaln::" << this->m_globaln);
69  }
70  /*
71  * Return the communication time
72  */
73  double globalW() {
74  int sendcnt = this->m_W.n_rows * this->m_W.n_cols;
75  int recvcnt = this->m_W.n_rows * this->m_W.n_cols;
76  this->m_Wt = this->m_W.t();
77  mpitic();
78  MPI_Allgather(this->m_Wt.memptr(), sendcnt, MPI_DOUBLE,
79  this->m_globalWt.memptr(), recvcnt, MPI_DOUBLE,
80  MPI_COMM_WORLD);
81  /*MPI_Gather(this->m_Wt.memptr(), sendcnt, MPI_DOUBLE,
82  this->m_globalWt.memptr(), recvcnt, MPI_DOUBLE,
83  0, MPI_COMM_WORLD);
84  sendcnt = this->m_globalWt.n_rows * this->m_globalWt.n_cols;
85  MPI_Bcast(this->m_globalWt.memptr(), sendcnt, MPI_DOUBLE, 0,
86  MPI_COMM_WORLD);*/
87  double commTime = mpitoc();
88  DISTPRINTINFO(PRINTMATINFO(this->m_Wt) << PRINTMATINFO(this->m_globalWt));
89  this->m_globalW = this->m_globalWt.t();
90  return commTime;
91  }
92  /*
93  * Return the communication time
94  */
95  double globalH() {
96  int sendcnt = this->m_H.n_rows * this->m_H.n_cols;
97  int recvcnt = this->m_H.n_rows * this->m_H.n_cols;
98  this->m_Ht = this->m_H.t();
99  mpitic();
100  MPI_Allgather(this->m_Ht.memptr(), sendcnt, MPI_DOUBLE,
101  this->m_globalHt.memptr(), recvcnt, MPI_DOUBLE,
102  MPI_COMM_WORLD);
103  /*MPI_Gather(this->m_Ht.memptr(), sendcnt, MPI_DOUBLE,
104  this->m_globalHt.memptr(), recvcnt, MPI_DOUBLE,
105  0, MPI_COMM_WORLD);
106  sendcnt = this->m_globalHt.n_rows * this->m_globalHt.n_cols;
107  MPI_Bcast(this->m_globalHt.memptr(), sendcnt, MPI_DOUBLE, 0,
108  MPI_COMM_WORLD);*/
109  double commTime = mpitoc();
110  this->m_globalH = this->m_globalHt.t();
111  DISTPRINTINFO(PRINTMATINFO(this->m_Ht) << PRINTMATINFO(this->m_globalHt));
112  return commTime;
113  }
114  /*
115  * Assuming you have the latest globalW and globalH.
116  * If other wise, call globalW and globalH before calling
117  * this function.
118  * (init.norm_A)^2 - 2*trace(H*(A'*W))+trace((W'*W)*(H*H'))
119  * each process owns globalsqnormA will have (init.norm_A)^2
120  *
121  */
122  void computeError(const MAT &WtW, const MAT &HtH) {
123  mpitic();
124  if (this->m_Acols.n_rows == this->m_globalm) {
125  HAtW = this->m_prevH.t() * (this->m_Acols.t() * this->m_globalW);
126  } else {
127  // we assume m_Acols would have been transposed
128  // by the derived classes.
129  HAtW = this->m_prevH.t() * (this->m_Acols * this->m_globalW);
130  }
131  double temp = mpitoc();
132  this->time_stats.err_compute_duration(temp);
133  mpitic();
134  MPI_Allreduce(HAtW.memptr(), globalHAtW.memptr(), this->m_k * this->m_k,
135  MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
136  temp = mpitoc();
137  this->time_stats.err_communication_duration(temp);
138  mpitic();
139  double tHAtW = trace(globalHAtW);
140  double tWtWHtH = trace(WtW * HtH);
141  PRINTROOT("normA::" << this->m_globalsqnormA << "::tHAtW::" << 2 * tHAtW
142  << "::tWtWHtH::" << tWtWHtH);
143  this->m_objective_err = this->m_globalsqnormA - 2 * tHAtW + tWtWHtH;
144  mpitoc();
145  this->time_stats.err_compute_duration(temp);
146  }
147  /*void computeError(const int it) {
148  mpitic();
149  err_matrix = this->m_globalW * this->m_prevH.t();
150  err_matrix = this->m_Acols - err_matrix;
151  PRINTROOT(PRINTMATINFO(this->m_globalW));
152  PRINTROOT(PRINTMATINFO(this->m_prevH));
153  PRINTROOT(PRINTMATINFO(err_matrix));
154  double error = norm(err_matrix, "fro");
155  error *= error;
156  double temp = mpitoc();
157  this->time_stats.err_compute_duration(temp);
158  mpitic();
159  MPI_Allreduce(&error, &(this->m_objective_err), 1,
160  MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
161  temp = mpitoc();
162  this->time_stats.err_communication_duration(temp);
163  }*/
164 
165  virtual void computeNMF() = 0;
166  const unsigned int num_iterations() const { return this->m_num_iterations; }
167  void num_iterations(int it) { m_num_iterations = it; }
168  const UWORD globalm() const { return m_globalm; }
169  const UWORD globaln() const { return m_globaln; }
170  MAT getLeftLowRankFactor() { return this->m_W; }
171  MAT getRightLowRankFactor() { return this->m_H; }
172  void compute_error(const uint &ce) { this->m_compute_error = ce; }
173  const bool is_compute_error() const { return (this->m_compute_error); }
174  void algorithm(algotype dat) { this->m_algorithm = dat; }
175  void reportTime(const double temp, const std::string &reportstring) {
176  double mintemp, maxtemp, sumtemp;
177  MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
178  MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
179  MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
180  PRINTROOT(reportstring << "::m::" << this->m_globalm
181  << "::n::" << this->m_globaln << "::k::" << this->m_k
182  << "::SIZE::" << MPI_SIZE
183  << "::algo::" << this->m_algorithm
184  << "::root::" << temp << "::min::" << mintemp
185  << "::avg::" << (sumtemp) / (MPI_SIZE)
186  << "::max::" << maxtemp);
187  }
188 };
189 
190 } // namespace planc
191 
192 #endif // DISTNMF_DISTNMF1D_HPP_
const double err_communication_duration() const
Definition: distnmftime.hpp:74
virtual void computeNMF()=0
const double err_compute_duration() const
Definition: distnmftime.hpp:73
void algorithm(algotype dat)
Definition: distnmf1D.hpp:174
double mpitoc(int rank)
Definition: distutils.hpp:22
MAT getRightLowRankFactor()
Definition: distnmf1D.hpp:171
void reportTime(const double temp, const std::string &reportstring)
Definition: distnmf1D.hpp:175
const UWORD globaln() const
Definition: distnmf1D.hpp:169
const unsigned int num_iterations() const
Definition: distnmf1D.hpp:166
algotype
Definition: utils.h:10
#define DISTPRINTINFO(MSG)
Definition: distutils.h:37
const bool is_compute_error() const
Definition: distnmf1D.hpp:173
void mpitic()
Definition: distutils.hpp:11
void computeError(const MAT &WtW, const MAT &HtH)
Definition: distnmf1D.hpp:122
MAT getLeftLowRankFactor()
Definition: distnmf1D.hpp:170
#define UWORD
Definition: utils.h:60
void num_iterations(int it)
Definition: distnmf1D.hpp:167
void compute_error(const uint &ce)
Definition: distnmf1D.hpp:172
#define MAT
Definition: utils.h:52
double globalW()
Definition: distnmf1D.hpp:73
ncp_factors contains the factors of the NCP; the i-th factor is of size n_i * k; the number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32
#define PRINTMATINFO(A)
Definition: utils.h:63
const UWORD globalm() const
Definition: distnmf1D.hpp:168
DistNMF1D(const INPUTMATTYPE &Arows, const INPUTMATTYPE &Acols, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &mpicomm)
Definition: distnmf1D.hpp:40
#define MPI_SIZE
Definition: distutils.h:15
double globalH()
Definition: distnmf1D.hpp:95