planc
Parallel Low-rank Approximation with Non-negativity Constraints
distaoadmm.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_DISTAOADMM_HPP_
4 #define DISTNMF_DISTAOADMM_HPP_
5 
6 #include "distnmf/aunmf.hpp"
7 
8 namespace planc {
9 
10 template <class INPUTMATTYPE>
11 class DistAOADMM : public DistAUNMF<INPUTMATTYPE> {
12  private:
13  MAT tempHtH;
14  MAT tempWtW;
15  MAT tempAHtij;
16  MAT tempWtAij;
17  ROWVEC localWnorm;
18  ROWVEC Wnorm;
19 
20  // Dual Variables
21  MAT U;
22  MAT Ut;
23  MAT V;
24  MAT Vt;
25 
26  // Auxiliary Variables
27  MAT Haux;
28  MAT Htaux;
29  MAT Waux;
30  MAT Wtaux;
31 
32  // Cholesky Variables
33  MAT L;
34  MAT Lt;
35  MAT tempHtaux;
36  MAT tempWtaux;
37 
38  // Hyperparameters
39  double alpha, beta, tolerance;
40  int admm_iter;
41 
42  void allocateMatrices() {
43  this->tempHtH.zeros(this->k, this->k);
44  this->tempWtW.zeros(this->k, this->k);
45  this->tempAHtij.zeros(size(this->AHtij));
46  this->tempWtAij.zeros(size(this->WtAij));
47 
48  // Normalise W, H
49  this->W = normalise(this->W, 2, 1);
50  this->Wt = this->W.t();
51  this->H = normalise(this->H);
52  this->Ht = this->H.t();
53 
54  // Dual Variables
55  this->V.zeros(size(this->H));
56  this->Vt = V.t();
57  this->U.zeros(size(this->W));
58  this->Ut = U.t();
59 
60  // Auxiliary Variables
61  this->Waux.zeros(size(this->W));
62  this->Wtaux = Waux.t();
63  this->Haux.zeros(size(this->H));
64  this->Htaux = Haux.t();
65 
66  // Hyperparameters
67  alpha = 0.0;
68  beta = 0.0;
69  tolerance = 0.01;
70  admm_iter = 5;
71 
72  this->L.zeros(this->k, this->k);
73  this->Lt = this->L.t();
74  this->tempWtaux.zeros(size(this->Wt));
75  this->tempHtaux.zeros(size(this->Ht));
76  }
77 
78  protected:
80  void updateW() {
81  // Calculate modified Gram Matrix
82  // tempHtH = arma::conv_to<MAT >::from(this->HtH);
83  tempHtH = this->HtH;
84  alpha = trace(tempHtH) / this->k;
85  alpha = alpha > 0 ? alpha : 0.01;
86  tempHtH.diag() += alpha;
87  L = arma::conv_to<MAT>::from(arma::chol(tempHtH, "lower"));
88  Lt = L.t();
89 
90  bool stop_iter = false;
91 
92  // Start ADMM loop from here
93  for (int i = 0; i < admm_iter && !stop_iter; i++) {
94  this->Waux = this->W;
95 
96  // tempAHtij = arma::conv_to<MAT >::from(this->AHtij);
97  tempAHtij = this->AHtij;
98  tempAHtij = tempAHtij + (alpha * (this->Wt + this->Ut));
99 
100  // Solve least squares
101  // tempWtaux = arma::conv_to<MAT >::from(
102  tempWtaux = arma::solve(arma::trimatl(L), tempAHtij);
103  Wtaux =
104  arma::conv_to<MAT>::from(arma::solve(arma::trimatu(Lt), tempWtaux));
105 
106  // Update W
107  // this->Wt = arma::conv_to<MAT >::from(Wtaux);
108  this->Wt = Wtaux;
109  fixNumericalError<MAT>(&(this->Wt), EPSILON_1EMINUS16);
110  this->Wt = this->Wt - this->Ut;
111  this->Wt.for_each(
112  [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
113  this->W = this->Wt.t();
114 
115  // Update Dual Variable
116  this->Ut = this->Ut + this->Wt - this->Wtaux;
117  this->U = this->Ut.t();
118 
119  // Check stopping criteria
120  double r = norm(this->Wt - this->Wtaux, "fro");
121  r *= r;
122  double globalr;
123  mpitic();
124  MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
125  double temp = mpitoc();
126  this->time_stats.communication_duration(temp);
127  this->time_stats.allreduce_duration(temp);
128  globalr = sqrt(globalr);
129 
130  double s = norm(this->W - this->Waux, "fro");
131  s *= s;
132  double globals;
133  mpitic();
134  MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
135  temp = mpitoc();
136  globals = sqrt(globals);
137 
138  double normW = norm(this->W, "fro");
139  normW *= normW;
140  double globalnormW;
141  mpitic();
142  MPI_Allreduce(&normW, &globalnormW, 1, MPI_DOUBLE, MPI_SUM,
143  MPI_COMM_WORLD);
144  temp = mpitoc();
145  globalnormW = sqrt(globalnormW);
146 
147  double normU = norm(this->U, "fro");
148  normU *= normU;
149  double globalnormU;
150  mpitic();
151  MPI_Allreduce(&normU, &globalnormU, 1, MPI_DOUBLE, MPI_SUM,
152  MPI_COMM_WORLD);
153  temp = mpitoc();
154  globalnormU = sqrt(globalnormU);
155 
156  if (globalr < (tolerance * globalnormW) &&
157  globals < (tolerance * globalnormU))
158  stop_iter = true;
159  }
160  }
162  void updateH() {
163  // Calculate the Gram Matrix
164  tempWtW = arma::conv_to<MAT>::from(this->WtW);
165  beta = trace(tempWtW) / this->k;
166  beta = beta > 0 ? beta : 0.01;
167  tempWtW.diag() += beta;
168  // L = arma::conv_to<MAT >::from(arma::chol(tempWtW, "lower"));
169  L = arma::chol(tempWtW, "lower");
170  Lt = L.t();
171 
172  bool stop_iter = false;
173 
174  // Start ADMM loop from here
175  for (int i = 0; i < admm_iter && !stop_iter; i++) {
176  this->Haux = this->H;
177 
178  // tempWtAij = arma::conv_to<MAT >::from(this->WtAij);
179  tempWtAij = this->WtAij;
180  tempWtAij = tempWtAij + (beta * (this->Ht + this->Vt));
181 
182  // Solve least squares
183  // tempHtaux = arma::conv_to<MAT >::from(
184  tempHtaux = arma::solve(arma::trimatl(L), tempWtAij);
185  // Htaux = arma::conv_to<MAT >::from(
186  Htaux = arma::solve(arma::trimatu(Lt), tempHtaux);
187  // Update H
188  // this->Ht = arma::conv_to<MAT >::from(Htaux);
189  this->Ht = Htaux;
190  fixNumericalError<MAT>(&(this->Ht), EPSILON_1EMINUS16);
191  this->Ht = this->Ht - this->Vt;
192  this->Ht.for_each(
193  [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
194  this->H = this->Ht.t();
195 
196  // Update Dual Variable
197  this->Vt = this->Vt + this->Ht - this->Htaux;
198  this->V = this->Vt.t();
199 
200  // Check stopping criteria
201  double r = norm(this->Ht - this->Htaux, "fro");
202  r *= r;
203  double globalr;
204  mpitic();
205  MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
206  double temp = mpitoc();
207  this->time_stats.communication_duration(temp);
208  this->time_stats.allreduce_duration(temp);
209  globalr = sqrt(globalr);
210 
211  double s = norm(this->H - this->Haux, "fro");
212  s *= s;
213  double globals;
214  mpitic();
215  MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
216  temp = mpitoc();
217  globals = sqrt(globals);
218 
219  double normH = norm(this->H, "fro");
220  normH *= normH;
221  double globalnormH;
222  mpitic();
223  MPI_Allreduce(&normH, &globalnormH, 1, MPI_DOUBLE, MPI_SUM,
224  MPI_COMM_WORLD);
225  temp = mpitoc();
226  globalnormH = sqrt(globalnormH);
227 
228  double normV = norm(this->V, "fro");
229  normV *= normV;
230  double globalnormV;
231  mpitic();
232  MPI_Allreduce(&normV, &globalnormV, 1, MPI_DOUBLE, MPI_SUM,
233  MPI_COMM_WORLD);
234  temp = mpitoc();
235  globalnormV = sqrt(globalnormV);
236 
237  if (globalr < (tolerance * globalnormH) &&
238  globals < (tolerance * globalnormV))
239  stop_iter = true;
240  }
241  }
242 
243  public:
244  DistAOADMM(const INPUTMATTYPE &input, const MAT &leftlowrankfactor,
245  const MAT &rightlowrankfactor, const MPICommunicator &communicator,
246  const int numkblks)
247  : DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
248  communicator, numkblks) {
249  localWnorm.zeros(this->k);
250  Wnorm.zeros(this->k);
251  allocateMatrices();
252  PRINTROOT("DistAOADMM() constructor successful");
253  }
254 
256  /*
257  tempHtH.clear();
258  tempWtW.clear();
259  tempAHtij.clear();
260  tempWtAij.clear();
261  */
262  }
263 }; // class DistAOADMM2D
264 
265 } // namespace planc
266 
267 #endif // DISTNMF_DISTAOADMM_HPP_
double mpitoc(int rank)
Definition: distutils.hpp:22
#define EPSILON_1EMINUS16
Definition: utils.h:43
void mpitic()
Definition: distutils.hpp:11
DistAOADMM(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
Definition: distaoadmm.hpp:244
#define MAT
Definition: utils.h:52
#define ROWVEC
Definition: utils.h:54
ncp_factors contains the factors of the NCP: the i-th factor is of size n_i × k, and the number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32