planc
Parallel Lowrank Approximation with Non-negativity Constraints
disthals.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_DISTHALS_HPP_
4 #define DISTNMF_DISTHALS_HPP_
5 
6 #include "distnmf/aunmf.hpp"
13 namespace planc {
14 
15 template <class INPUTMATTYPE>
16 class DistHALS : public DistAUNMF<INPUTMATTYPE> {
17  protected:
26  void updateW() {
27  for (unsigned int i = 0; i < this->k; i++) {
28  // W(:,i) = max(W(:,i) * HHt_reg(i,i) + AHt(:,i) - W *
29  // HHt_reg(:,i),epsilon);
30  VEC updWi = this->W.col(i) * this->HtH(i, i) +
31  ((this->AHtij.row(i)).t() - this->W * this->HtH.col(i));
32 #ifdef MPI_VERBOSE
33  DISTPRINTINFO("b4 fixNumericalError::" << endl << updWi);
34 #endif // ifdef MPI_VERBOSE
35  fixNumericalError<VEC>(&updWi);
36 #ifdef MPI_VERBOSE
37  DISTPRINTINFO("after fixNumericalError::" << endl << updWi);
38 #endif // ifdef MPI_VERBOSE
39 
40  // W(:,i) = W(:,i)/norm(W(:,i));
41  double normWi = arma::norm(updWi, 2);
42  normWi *= normWi;
43  double globalnormWi;
44  mpitic();
45  MPI_Allreduce(&normWi, &globalnormWi, 1, MPI_DOUBLE, MPI_SUM,
46  MPI_COMM_WORLD);
47  double temp = mpitoc();
48  this->time_stats.communication_duration(temp);
49  this->time_stats.allreduce_duration(temp);
50 
51  if (globalnormWi > 0) {
52  this->W.col(i) = updWi / sqrt(globalnormWi);
53  this->H.col(i) = this->H.col(i) * sqrt(globalnormWi);
54  }
55  }
56  this->Wt = this->W.t();
57  }
58 
67  void updateH() {
68  for (unsigned int i = 0; i < this->k; i++) {
69  // H(i,:) = max(H(i,:) + WtA(i,:) - WtW_reg(i,:) * H,epsilon);
70  VEC updHi = this->H.col(i) +
71  ((this->WtAij.row(i)).t() - this->H * this->WtW.col(i));
72 #ifdef MPI_VERBOSE
73  DISTPRINTINFO("b4 fixNumericalError::" << endl << updHi);
74 #endif // ifdef MPI_VERBOSE
75  fixNumericalError<VEC>(&updHi);
76 #ifdef MPI_VERBOSE
77  DISTPRINTINFO("after fixNumericalError::" << endl << updHi);
78 #endif // ifdef MPI_VERBOSE
79  double normHi = arma::norm(updHi, 2);
80  normHi *= normHi;
81  double globalnormHi;
82  mpitic();
83  MPI_Allreduce(&normHi, &globalnormHi, 1, MPI_DOUBLE, MPI_SUM,
84  MPI_COMM_WORLD);
85  double temp = mpitoc();
86  this->time_stats.communication_duration(temp);
87  this->time_stats.allreduce_duration(temp);
88 
89  if (globalnormHi > 0) {
90  this->H.col(i) = updHi;
91  }
92  }
93  this->Ht = this->H.t();
94  }
95 
96  public:
97  DistHALS(const INPUTMATTYPE& input, const MAT& leftlowrankfactor,
98  const MAT& rightlowrankfactor, const MPICommunicator& communicator,
99  const int numkblks)
100  : DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
101  communicator, numkblks) {
102  PRINTROOT("DistHALS() constructor successful");
103  }
104 };
105 
106 } // namespace planc
107 
108 #endif // DISTNMF_DISTHALS_HPP_
double mpitoc(int rank)
Definition: distutils.hpp:22
#define DISTPRINTINFO(MSG)
Definition: distutils.h:37
DistHALS(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
Definition: disthals.hpp:97
void mpitic()
Definition: distutils.hpp:11
#define MAT
Definition: utils.h:52
ncp_factors contains the factors of the NCP; the i-th factor is of size n_i * k; the number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32
#define VEC
Definition: utils.h:61