3 #ifndef DISTNMF_DISTHALS_HPP_ 4 #define DISTNMF_DISTHALS_HPP_ 15 template <
class INPUTMATTYPE>
27 for (
unsigned int i = 0; i < this->k; i++) {
30 VEC updWi = this->W.col(i) * this->HtH(i, i) +
31 ((this->AHtij.row(i)).t() - this->W * this->HtH.col(i));
34 #endif // ifdef MPI_VERBOSE 35 fixNumericalError<VEC>(&updWi);
38 #endif // ifdef MPI_VERBOSE 41 double normWi = arma::norm(updWi, 2);
45 MPI_Allreduce(&normWi, &globalnormWi, 1, MPI_DOUBLE, MPI_SUM,
48 this->time_stats.communication_duration(temp);
49 this->time_stats.allreduce_duration(temp);
51 if (globalnormWi > 0) {
52 this->W.col(i) = updWi / sqrt(globalnormWi);
53 this->H.col(i) = this->H.col(i) * sqrt(globalnormWi);
56 this->Wt = this->W.t();
68 for (
unsigned int i = 0; i < this->k; i++) {
70 VEC updHi = this->H.col(i) +
71 ((this->WtAij.row(i)).t() - this->H * this->WtW.col(i));
74 #endif // ifdef MPI_VERBOSE 75 fixNumericalError<VEC>(&updHi);
78 #endif // ifdef MPI_VERBOSE 79 double normHi = arma::norm(updHi, 2);
83 MPI_Allreduce(&normHi, &globalnormHi, 1, MPI_DOUBLE, MPI_SUM,
86 this->time_stats.communication_duration(temp);
87 this->time_stats.allreduce_duration(temp);
89 if (globalnormHi > 0) {
90 this->H.col(i) = updHi;
93 this->Ht = this->H.t();
97 DistHALS(
const INPUTMATTYPE& input,
const MAT& leftlowrankfactor,
100 :
DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
101 communicator, numkblks) {
102 PRINTROOT(
"DistHALS() constructor successful");
108 #endif // DISTNMF_DISTHALS_HPP_
#define DISTPRINTINFO(MSG)
DistHALS(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
ncp_factors contains the factors of the NCP. Every i-th factor is of size n_i * k. The number of factors is ...