planc
Parallel Lowrank Approximation with Non-negativity Constraints
nmf.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 #ifndef COMMON_NMF_HPP_
3 #define COMMON_NMF_HPP_
4 #include <assert.h>
5 #include <string>
6 #include "common/utils.hpp"
7 
8 // #ifndef _VERBOSE
9 // #define _VERBOSE 1;
10 // #endif
11 
12 #define NUM_THREADS 4
13 #define CONV_ERR 0.000001
14 #define NUM_STATS 9
15 
16 // #ifndef COLLECTSTATS
17 // #define COLLECTSTATS 1
18 // #endif
19 
20 namespace planc {
21 
22 // T must be either an instance of MAT or sp_MAT
23 template <class T>
24 class NMF {
25  protected:
26  T A;
27  MAT W, H;
28  MAT Winit, Hinit;
29  UINT m, n, k;
30 
31  /*
32  * Collected statistics are
33  * iteration Htime Wtime totaltime normH normW densityH densityW relError
34  */
35  MAT stats;
36  double objective_err;
37  double normA, normW, normH;
38  double densityW, densityH;
39  bool cleared;
40  unsigned int m_num_iterations;
41  std::string input_file_name;
42  MAT errMtx; // used for error computation.
43  T A_err_sub_mtx; // used for error computation.
46  FVEC m_regW;
47  FVEC m_regH;
48 
49  void collectStats(int iteration) {
50  this->normW = arma::norm(this->W, "fro");
51  this->normH = arma::norm(this->H, "fro");
52  UVEC nnz = find(this->W > 0);
53  this->densityW = nnz.size() / (this->m * this->k);
54  nnz.clear();
55  nnz = find(this->H > 0);
56  this->densityH = nnz.size() / (this->m * this->k);
57  this->stats(iteration, 4) = this->normH;
58  this->stats(iteration, 5) = this->normW;
59  this->stats(iteration, 6) = this->densityH;
60  this->stats(iteration, 7) = this->densityW;
61  this->stats(iteration, 8) = this->objective_err;
62  }
63 
72  void applyReg(const FVEC &reg, MAT *AtA) {
73  // Frobenius norm regularization
74  if (reg(0) > 0) {
75  MAT identity = arma::eye<MAT>(this->k, this->k);
76  float lambda_l2 = reg(0);
77  (*AtA) = (*AtA) + 2 * lambda_l2 * identity;
78  }
79 
80  // L1 - norm regularization
81  if (reg(1) > 0) {
82  MAT onematrix = arma::ones<MAT>(this->k, this->k);
83  float lambda_l1 = reg(1);
84  (*AtA) = (*AtA) + 2 * lambda_l1 * onematrix;
85  }
86  }
87 
92  void normalize_by_W() {
93  MAT W_square = arma::pow(this->W, 2);
94  ROWVEC norm2 = arma::sqrt(arma::sum(W_square, 0));
95  for (unsigned int i = 0; i < this->k; i++) {
96  if (norm2(i) > 0) {
97  this->W.col(i) = this->W.col(i) / norm2(i);
98  this->H.col(i) = this->H.col(i) * norm2(i);
99  }
100  }
101  }
102 
103  private:
104  void otherInitializations() {
105  this->stats.zeros();
106  this->cleared = false;
107  this->normA = arma::norm(this->A, "fro");
108  this->m_num_iterations = 20;
109  this->objective_err = 1000000000000;
110  this->stats.resize(m_num_iterations + 1, NUM_STATS);
111  }
112 
113  public:
119  NMF(const T &input, const unsigned int rank) {
120  this->A = input;
121  this->m = A.n_rows;
122  this->n = A.n_cols;
123  this->k = rank;
124  // prime number closer to W.
125  arma::arma_rng::set_seed(89);
126  this->W = arma::randu<MAT>(m, k);
127  // prime number close to H
128  arma::arma_rng::set_seed(73);
129  this->H = arma::randu<MAT>(n, k);
130  this->m_regW = arma::zeros<FVEC>(2);
131  this->m_regH = arma::zeros<FVEC>(2);
132  normalize_by_W();
133 
134  // make the random MATrix positive
135  // absMAT<MAT>(W);
136  // absMAT<MAT>(H);
137  // other intializations
138  this->otherInitializations();
139  }
146  NMF(const T &input, const MAT &leftlowrankfactor,
147  const MAT &rightlowrankfactor) {
148  assert(leftlowrankfactor.n_cols == rightlowrankfactor.n_cols);
149  this->A = input;
150  this->W = leftlowrankfactor;
151  this->H = rightlowrankfactor;
152  this->Winit = this->W;
153  this->Hinit = this->H;
154  this->m = A.n_rows;
155  this->n = A.n_cols;
156  this->k = W.n_cols;
157  this->m_regW = arma::zeros<FVEC>(2);
158  this->m_regH = arma::zeros<FVEC>(2);
159 
160  // other initializations
161  this->otherInitializations();
162  }
163 
// Pure virtual: every concrete subclass must provide the factorization
// routine; NMF itself cannot be instantiated.
// NOTE(review): presumably implementations update W and H in place -
// confirm against the subclasses.
164  virtual void computeNMF() = 0;
165 
167  MAT getLeftLowRankFactor() { return W; }
169  MAT getRightLowRankFactor() { return H; }
170 
171  /*
172  * A is mxn
173  * Wr is mxk will be overwritten. Must be passed with values of W.
174  * Hr is nxk will be overwritten. Must be passed with values of H.
175  * All matrices are in row major format
176  * ||A-WH||_F^2 = over all nnz (a_ij - w_i h_j)^2 +
177  * over all zeros (w_i h_j)^2
178  * = over all nnz (a_ij - w_i h_j)^2 +
179  ||WH||_F^2 - over all nnz (w_i h_j)^2
180  *
181  */
182 #if 0
// Sparse-matrix variant of the objective error. It is compiled out by the
// surrounding "#if 0" and kept for reference only.
// Reads A through col_ptrs / row_indices / values, so only the stored
// entries of A are visited. Accumulates in single precision (float).
// Unlike the enabled overloads, which store a squared error, this variant
// stores the square root in objective_err.
183  void computeObjectiveError() {
184  // 1. over all nnz (a_ij - w_i h_j)^2
185  // 2. over all nnz (w_i h_j)^2
186  // 3. Compute R of W and L of H through QR
187  // 4. use sgemm to compute RL
188  // 5. use slange to compute ||RL||_F^2
189  // 6. store sqrt(nnzsse + ||RL||_F^2 - nnzwh) in objective_err
190  tic();
191  float nnzsse = 0;
192  float nnzwh = 0;
193  MAT Rw(this->k, this->k);
194  MAT Rh(this->k, this->k);
195  MAT Qw(this->m, this->k);
196  MAT Qh(this->n, this->k);
197  MAT RwRh(this->k, this->k);
198 
199  // #pragma omp parallel for reduction (+ : nnzsse,nnzwh)
// jj is 1-based so that col_ptrs[jj - 1] .. col_ptrs[jj] brackets the stored
// entries of column jj - 1.
200  for (UWORD jj = 1; jj <= this->A.n_cols; jj++) {
201  UWORD startIdx = this->A.col_ptrs[jj - 1];
202  UWORD endIdx = this->A.col_ptrs[jj];
203  UWORD col = jj - 1;
204  float nnzssecol = 0;
205  float nnzwhcol = 0;
206 
207  for (UWORD ii = startIdx; ii < endIdx; ii++) {
208  UWORD row = this->A.row_indices[ii];
209  float tempsum = 0;
210 
// tempsum = dot product of row `row` of W with row `col` of H.
211  for (UWORD kk = 0; kk < k; kk++) {
212  tempsum += (this->W(row, kk) * this->H(col, kk));
213  }
214  nnzwhcol += tempsum * tempsum;
215  nnzssecol += (this->A.values[ii] - tempsum)
216  * (this->A.values[ii] - tempsum);
217  }
218  nnzsse += nnzssecol;
219  nnzwh += nnzwhcol;
220  }
// The Frobenius norm of Rw * Rh^T is used below as the norm of W * H^T.
221  qr_econ(Qw, Rw, this->W);
222  qr_econ(Qh, Rh, this->H);
223  RwRh = Rw * Rh.t();
224  float normWH = arma::norm(RwRh, "fro");
225  Rw.clear();
226  Rh.clear();
227  Qw.clear();
228  Qh.clear();
229  RwRh.clear();
230  INFO << "error compute time " << toc() << std::endl;
231  float fastErr = sqrt(nnzsse + (normWH * normWH - nnzwh));
232  this->objective_err = fastErr;
233 
234  // return (fastErr);
235  }
236 
237 #else // ifdef BUILD_SPARSE
239  // (init.norm_A)^2 - 2*trace(H'*(A'*W))+trace((W'*W)*(H*H'))
240  // MAT WtW = this->W.t() * this->W;
241  // MAT HtH = this->H.t() * this->H;
242  // MAT AtW = this->A.t() * this->W;
243 
244  // double sqnormA = this->normA * this->normA;
245  // double TrHtAtW = arma::trace(this->H.t() * AtW);
246  // double TrWtWHtH = arma::trace(WtW * HtH);
247 
248  // this->objective_err = sqnormA - (2 * TrHtAtW) + TrWtWHtH;
249 #ifdef _VERBOSE
250  INFO << "Entering computeObjectiveError A=" << this->A.n_rows << "x"
251  << this->A.n_cols << " W = " << this->W.n_rows << "x" << this->W.n_cols
252  << " H=" << this->H.n_rows << "x" << this->H.n_cols << std::endl;
253 #endif
254  tic();
255  // always restrict the errMtx size to fit it in memory
256  // and doesn't occupy much space.
257  // For eg., the max we can have only 3 x 10^6 elements.
258  // The number of columns must be chosen appropriately.
259  UWORD PER_SPLIT = std::ceil((3 * 1e6) / A.n_rows);
260  // UWORD PER_SPLIT = 1;
261  // always colSplit. Row split is really slow as the matrix is col major
262  // always
263  bool colSplit = true;
264  // if (this->A.n_rows > PER_SPLIT || this->A.n_cols > PER_SPLIT) {
265  uint numSplits = 1;
266  MAT Ht = this->H.t();
267  if (this->A.n_cols > PER_SPLIT) {
268  // if (this->A.n_cols < this->A.n_rows)
269  // colSplit = false;
270  if (colSplit)
271  numSplits = A.n_cols / PER_SPLIT;
272  else
273  numSplits = A.n_rows / PER_SPLIT;
274  // #ifdef _VERBOSE
275  } else {
276  PER_SPLIT = A.n_cols;
277  numSplits = 1;
278  }
279 #ifdef _VERBOSE
280  INFO << "PER_SPLIT = " << PER_SPLIT << "numSplits = " << numSplits
281  << std::endl;
282 #endif
283  // #endif
284  VEC splitErr = arma::zeros<VEC>(numSplits + 1);
285  // allocate one and never allocate again.
286  if (colSplit && errMtx.n_rows == 0 && errMtx.n_cols == 0) {
287  errMtx = arma::zeros<MAT>(A.n_rows, PER_SPLIT);
288  A_err_sub_mtx = arma::zeros<T>(A.n_rows, PER_SPLIT);
289  } else {
290  errMtx = arma::zeros<MAT>(PER_SPLIT, A.n_cols);
291  A_err_sub_mtx = arma::zeros<T>(PER_SPLIT, A.n_cols);
292  }
293  for (unsigned int i = 0; i <= numSplits; i++) {
294  UWORD beginIdx = i * PER_SPLIT;
295  UWORD endIdx = (i + 1) * PER_SPLIT - 1;
296  if (colSplit) {
297  if (endIdx > A.n_cols) endIdx = A.n_cols - 1;
298  if (beginIdx < endIdx) {
299 #ifdef _VERBOSE
300  INFO << "beginIdx=" << beginIdx << " endIdx= " << endIdx << std::endl;
301  INFO << "Ht = " << Ht.n_rows << "x" << Ht.n_cols << std::endl;
302 
303 #endif
304  errMtx = W * Ht.cols(beginIdx, endIdx);
305  A_err_sub_mtx = A.cols(beginIdx, endIdx);
306  } else if (beginIdx == endIdx && beginIdx < A.n_cols) {
307  errMtx = W * Ht.col(beginIdx);
308  A_err_sub_mtx = A.col(beginIdx);
309  }
310  } else {
311  if (endIdx > A.n_rows) endIdx = A.n_rows - 1;
312 #ifdef _VERBOSE
313  INFO << "beginIdx=" << beginIdx << " endIdx= " << endIdx << std::endl;
314 #endif
315  if (beginIdx < endIdx) {
316  A_err_sub_mtx = A.rows(beginIdx, endIdx);
317  errMtx = W.rows(beginIdx, endIdx) * Ht;
318  }
319  }
320  A_err_sub_mtx -= errMtx;
321  A_err_sub_mtx %= A_err_sub_mtx;
322  splitErr(i) = arma::accu(A_err_sub_mtx);
323  }
324  double err_time = toc();
325  INFO << "err compute time::" << err_time << std::endl;
326  this->objective_err = arma::sum(splitErr);
327  }
328 
329 #endif // ifdef BUILD_SPARSE
330  void computeObjectiveError(const T &At, const MAT &WtW, const MAT &HtH) {
331  MAT AtW = At * this->W;
332 
333  double sqnormA = this->normA * this->normA;
334  double TrHtAtW = arma::trace(this->H.t() * AtW);
335  double TrWtWHtH = arma::trace(WtW * HtH);
336 
337  this->objective_err = sqnormA - (2 * TrHtAtW) + TrWtWHtH;
338  }
340  void num_iterations(const int it) { this->m_num_iterations = it; }
342  void regW(const FVEC &iregW) { this->m_regW = iregW; }
344  void regH(const FVEC &iregH) { this->m_regH = iregH; }
346  FVEC regW() { return this->m_regW; }
348  FVEC regH() { return this->m_regH; }
350  const unsigned int num_iterations() const { return m_num_iterations; }
351 
352  ~NMF() { clear(); }
355  void clear() {
356  if (!this->cleared) {
357  this->A.clear();
358  this->W.clear();
359  this->H.clear();
360  this->stats.clear();
361  if (errMtx.n_rows != 0 && errMtx.n_cols != 0) {
362  errMtx.clear();
363  A_err_sub_mtx.clear();
364  }
365  this->cleared = true;
366  }
367  }
368 };
369 } // namespace planc
370 #endif // COMMON_NMF_HPP_
NMF(const T &input, const unsigned int rank)
Constructor with an input matrix and a low rank.
Definition: nmf.hpp:119
~NMF()
Definition: nmf.hpp:352
void regW(const FVEC &iregW)
Sets the regularization on left low rank factor W.
Definition: nmf.hpp:342
void computeObjectiveError(const T &At, const MAT &WtW, const MAT &HtH)
Definition: nmf.hpp:330
void tic()
Start the timer. Easy to use: tic(); some code; double t = toc();
Definition: utils.hpp:42
void clear()
Clear the memory for input matrix A, left low rank factor W and right low rank factor H...
Definition: nmf.hpp:355
const unsigned int num_iterations() const
Returns the number of iterations.
Definition: nmf.hpp:350
#define FVEC
Definition: utils.h:55
#define UVEC
Definition: utils.h:58
double toc()
Definition: utils.hpp:48
MAT getRightLowRankFactor()
Returns the right low rank factor matrix H.
Definition: nmf.hpp:169
void num_iterations(const int it)
Sets number of iterations for the NMF algorithms.
Definition: nmf.hpp:340
#define INFO
Definition: utils.h:36
FVEC regH()
Returns the L2 and L1 regularization parameters of H as a vector.
Definition: nmf.hpp:348
#define UWORD
Definition: utils.h:60
NMF(const T &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor)
Constructor with initial left and right low rank factors. Necessary when you want to compare algorithm...
Definition: nmf.hpp:146
virtual void computeNMF()=0
unsigned int UINT
Definition: utils.h:68
#define MAT
Definition: utils.h:52
MAT getLeftLowRankFactor()
Returns the left low rank factor matrix W.
Definition: nmf.hpp:167
#define NUM_STATS
Definition: nmf.hpp:14
FVEC regW()
Returns the L2 and L1 regularization parameters of W as a vector.
Definition: nmf.hpp:346
#define ROWVEC
Definition: utils.h:54
ncp_factors contains the factors of the ncp; every ith factor is of size n_i * k; number of factors is ...
Definition: ncpfactors.hpp:20
#define VEC
Definition: utils.h:61
void computeObjectiveError()
Definition: nmf.hpp:238
void regH(const FVEC &iregH)
Sets the regularization on right low rank H.
Definition: nmf.hpp:344