planc
Parallel Low-rank Approximation with Non-negativity Constraints
hals.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef NMF_HALS_HPP_
4 #define NMF_HALS_HPP_
5 
6 #include "common/nmf.hpp"
7 
8 namespace planc {
9 
10 template <class T>
11 class HALSNMF : public NMF<T> {
12  private:
13  // Not happy with this design. However to avoid computing At again and again
14  // making this as private variable.
15  T At;
16  MAT WtW;
17  MAT HtH;
18  MAT WtA;
19  MAT AH;
20 
21  /*
22  * Collected statistics are
23  * iteration Htime Wtime totaltime normH normW densityH densityW relError
24  */
25  void allocateMatrices() {
26  WtW = arma::zeros<MAT>(this->k, this->k);
27  HtH = arma::zeros<MAT>(this->k, this->k);
28  WtA = arma::zeros<MAT>(this->n, this->k);
29  AH = arma::zeros<MAT>(this->m, this->k);
30  }
31  void freeMatrices() {
32  this->At.clear();
33  WtW.clear();
34  HtH.clear();
35  WtA.clear();
36  AH.clear();
37  }
38 
39  public:
40  HALSNMF(const T &A, int lowrank) : NMF<T>(A, lowrank) {
41  this->normalize_by_W();
42  allocateMatrices();
43  this->At = this->A.t();
44  }
45  HALSNMF(const T &A, const MAT &llf, const MAT &rlf) : NMF<T>(A, llf, rlf) {
46  this->normalize_by_W();
47  allocateMatrices();
48  this->At = this->A.t();
49  }
50  void computeNMF() {
51  unsigned int currentIteration = 0;
52  INFO << "computed transpose At=" << PRINTMATINFO(this->At) << std::endl;
53  while (currentIteration < this->num_iterations()) {
54  tic();
55  // update H
56  tic();
57  WtA = this->W.t() * this->A;
58  WtW = this->W.t() * this->W;
59  INFO << "starting H Prereq for "
60  << " took=" << toc() << PRINTMATINFO(WtW) << PRINTMATINFO(WtA)
61  << std::endl;
62  // to avoid divide by zero error.
63  tic();
64  double normConst;
65  VEC Hx;
66  for (unsigned int x = 0; x < this->k; x++) {
67  // H(i,:) = max(H(i,:) + WtA(i,:) - WtW_reg(i,:) * H,epsilon);
68  Hx = this->H.col(x) + (((WtA.row(x)).t()) - (this->H * (WtW.col(x))));
69  fixNumericalError<VEC>(&Hx);
70  normConst = norm(Hx);
71  if (normConst != 0) {
72  this->H.col(x) = Hx;
73  }
74  }
75  INFO << "Completed H (" << currentIteration << "/"
76  << this->num_iterations() << ")"
77  << " time =" << toc() << std::endl;
78  // update W;
79  tic();
80  AH = this->A * this->H;
81  HtH = this->H.t() * this->H;
82  INFO << "starting W Prereq for "
83  << " took=" << toc() << PRINTMATINFO(HtH) << PRINTMATINFO(AH)
84  << std::endl;
85  tic();
86  VEC Wx;
87  for (unsigned int x = 0; x < this->k; x++) {
88  // FVEC Wx = W(:,x) + (AHt(:,x)-W*HHt(:,x))/HHtDiag(x);
89 
90  // W(:,i) = W(:,i) * HHt_reg(i,i) + AHt(:,i) - W * HHt_reg(:,i);
91  Wx = (this->W.col(x) * HtH(x, x)) +
92  (((AH.col(x))) - (this->W * (HtH.col(x))));
93  fixNumericalError<VEC>(&Wx);
94  normConst = norm(Wx);
95  if (normConst != 0) {
96  Wx = Wx / normConst;
97  this->W.col(x) = Wx;
98  }
99  }
100  this->normalize_by_W();
101 
102  INFO << "Completed W (" << currentIteration << "/"
103  << this->num_iterations() << ")"
104  << " time =" << toc() << std::endl;
105 
106  INFO << "Completed It (" << currentIteration << "/"
107  << this->num_iterations() << ")"
108  << " time =" << toc() << std::endl;
109  this->computeObjectiveError();
110  INFO << "Completed it = " << currentIteration
111  << " HALSERR=" << sqrt(this->objective_err) / this->normA
112  << std::endl;
113  currentIteration++;
114  }
115  }
116  ~HALSNMF() {}
117 };
118 
119 } // namespace planc
120 
121 #endif // NMF_HALS_HPP_
void tic()
Starts the timer. Usage: tic(); /* some code */ double t = toc();
Definition: utils.hpp:42
const unsigned int num_iterations() const
Returns the number of iterations.
Definition: nmf.hpp:350
double toc()
Definition: utils.hpp:48
void computeNMF()
Definition: hals.hpp:50
HALSNMF(const T &A, const MAT &llf, const MAT &rlf)
Definition: hals.hpp:45
#define INFO
Definition: utils.h:36
HALSNMF(const T &A, int lowrank)
Definition: hals.hpp:40
#define MAT
Definition: utils.h:52
ncp_factors contains the factors of the NCP; the i-th factor is of size n_i x k, and the number of factors is ...
Definition: ncpfactors.hpp:20
#define VEC
Definition: utils.h:61
#define PRINTMATINFO(A)
Definition: utils.h:63
void computeObjectiveError()
Definition: nmf.hpp:238