planc
Parallel Lowrank Approximation with Non-negativity Constraints
aoadmm.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef NMF_AOADMM_HPP_
4 #define NMF_AOADMM_HPP_
5 
6 #include "common/nmf.hpp"
7 
8 namespace planc {
9 
10 template <class T>
11 class AOADMMNMF : public NMF<T> {
12  private:
13  // Not happy with this design. However to avoid computing At again and again
14  // making this as private variable.
15  T At;
16  MAT WtW;
17  MAT HtH;
18  MAT WtA;
19  MAT AH;
20 
21  // Dual Variables
22  MAT U;
23  MAT V;
24 
25  // Auxiliary/Temporary Variables
26  MAT Htaux;
27  MAT tempHtaux;
28  MAT H0;
29  MAT Wtaux;
30  MAT tempWtaux;
31  MAT W0;
32  MAT L;
33 
34  // Hyperparameters
35  double alpha, beta, tolerance;
36  int admm_iter;
37 
38  /*
39  * Collected statistics are
40  * iteration Htime Wtime totaltime normH normW densityH densityW relError
41  */
42  void allocateMatrices() {
43  WtW = arma::zeros<MAT>(this->k, this->k);
44  HtH = arma::zeros<MAT>(this->k, this->k);
45  WtA = arma::zeros<MAT>(this->n, this->k);
46  AH = arma::zeros<MAT>(this->m, this->k);
47 
48  // Dual Variables
49  U.zeros(size(this->W));
50  V.zeros(size(this->H));
51 
52  // Auxiliary/Temporary Variables
53  Htaux.zeros(size(this->H.t()));
54  H0.zeros(size(this->H));
55  tempHtaux.zeros(size(this->H.t()));
56  Wtaux.zeros(size(this->W.t()));
57  W0.zeros(size(this->W));
58  tempWtaux.zeros(size(this->W.t()));
59  L.zeros(this->k, this->k);
60 
61  // Hyperparameters
62  alpha = 0.0;
63  beta = 0.0;
64  tolerance = 0.01;
65  admm_iter = 5;
66  }
67  void freeMatrices() {
68  this->At.clear();
69  WtW.clear();
70  HtH.clear();
71  WtA.clear();
72  AH.clear();
73  }
74 
75  public:
76  AOADMMNMF(const T &A, int lowrank) : NMF<T>(A, lowrank) {
77  this->normalize_by_W();
78  allocateMatrices();
79  }
80  AOADMMNMF(const T &A, const MAT &llf, const MAT &rlf) : NMF<T>(A, llf, rlf) {
81  this->normalize_by_W();
82  allocateMatrices();
83  }
84  void computeNMF() {
85  unsigned int currentIteration = 0;
86  this->At = this->A.t();
87  INFO << "computed transpose At=" << PRINTMATINFO(this->At) << std::endl;
88  while (currentIteration < this->num_iterations()) {
89  tic();
90  // update H
91  tic();
92  WtA = this->W.t() * this->A;
93  WtW = this->W.t() * this->W;
94  beta = trace(WtW) / this->k;
95  beta = beta > 0 ? beta : 0.01;
96  WtW.diag() += beta;
97 
98  INFO << "starting H Prereq for "
99  << " took=" << toc() << PRINTMATINFO(WtW) << PRINTMATINFO(WtA)
100  << std::endl;
101  // to avoid divide by zero error.
102  tic();
103  L = arma::chol(WtW, "lower");
104 
105  bool stop_iter = false;
106 
107  // Start ADMM loop from here
108  for (int i = 0; i < admm_iter && !stop_iter; i++) {
109  H0 = this->H;
110  tempHtaux =
111  arma::solve(arma::trimatl(L), WtA + (beta * (this->H.t() + V.t())));
112  Htaux = arma::solve(arma::trimatu(L.t()), tempHtaux);
113 
114  this->H = Htaux.t();
115  fixNumericalError<MAT>(&(this->H), EPSILON_1EMINUS16);
116  this->H = this->H - V;
117  this->H.for_each(
118  [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
119  V = V + this->H - Htaux.t();
120 
121  // Check stopping criteria
122  double r = norm(this->H - Htaux.t(), "fro");
123  double s = norm(this->H - H0, "fro");
124  double normH = norm(this->H, "fro");
125  double normV = norm(V, "fro");
126 
127  if (r < (tolerance * normH) && s < (tolerance * normV))
128  stop_iter = true;
129  }
130 
131  INFO << "Completed H (" << currentIteration << "/"
132  << this->num_iterations() << ")"
133  << " time =" << toc() << std::endl;
134 
135  // update W;
136  tic();
137  AH = this->A * this->H;
138  HtH = this->H.t() * this->H;
139  alpha = trace(HtH) / this->k;
140  alpha = alpha > 0 ? alpha : 0.01;
141  HtH.diag() += alpha;
142 
143  INFO << "starting W Prereq for "
144  << " took=" << toc() << PRINTMATINFO(HtH) << PRINTMATINFO(AH)
145  << std::endl;
146  tic();
147  L = arma::chol(HtH, "lower");
148 
149  stop_iter = false;
150 
151  // Start ADMM loop from here
152  for (int i = 0; i < admm_iter && !stop_iter; i++) {
153  W0 = this->W;
154  tempWtaux = arma::solve(arma::trimatl(L),
155  AH.t() + alpha * (this->W.t() + U.t()));
156  Wtaux = arma::solve(arma::trimatu(L.t()), tempWtaux);
157 
158  this->W = Wtaux.t();
159  fixNumericalError<MAT>(&(this->W), EPSILON_1EMINUS16);
160  this->W = this->W - U;
161  this->W.for_each(
162  [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
163 
164  U = U + this->W - Wtaux.t();
165 
166  // Check stopping criteria
167  double r = norm(this->W - Wtaux.t(), "fro");
168  double s = norm(this->W - W0, "fro");
169  double normW = norm(this->W, "fro");
170  double normU = norm(U, "fro");
171 
172  if (r < (tolerance * normW) && s < (tolerance * normU))
173  stop_iter = true;
174  }
175 
176  INFO << "Completed W (" << currentIteration << "/"
177  << this->num_iterations() << ")"
178  << " time =" << toc() << std::endl;
179 
180  INFO << "Completed It (" << currentIteration << "/"
181  << this->num_iterations() << ")"
182  << " time =" << toc() << std::endl;
183  this->computeObjectiveError();
184  INFO << "Completed it = " << currentIteration
185  << " AOADMMERR=" << sqrt(this->objective_err) / this->normA
186  << std::endl;
187  currentIteration++;
188  }
189  }
191 };
192 
193 } // namespace planc
194 
195 #endif // NMF_AOADMM_HPP_
AOADMMNMF(const T &A, const MAT &llf, const MAT &rlf)
Definition: aoadmm.hpp:80
void tic()
Start the timer. Easy to call as tic(); some code; double t = toc();
Definition: utils.hpp:42
#define EPSILON_1EMINUS16
Definition: utils.h:43
const unsigned int num_iterations() const
Returns the number of iterations.
Definition: nmf.hpp:350
AOADMMNMF(const T &A, int lowrank)
Definition: aoadmm.hpp:76
double toc()
Definition: utils.hpp:48
#define INFO
Definition: utils.h:36
void computeNMF()
Definition: aoadmm.hpp:84
#define MAT
Definition: utils.h:52
ncp_factors contains the factors of the NCP; every i-th factor is of size n_i * k; the number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTMATINFO(A)
Definition: utils.h:63
void computeObjectiveError()
Definition: nmf.hpp:238