3 #ifndef DISTNMF_DISTAOADMM_HPP_ 4 #define DISTNMF_DISTAOADMM_HPP_ 10 template <
class INPUTMATTYPE>
39 double alpha, beta, tolerance;
42 void allocateMatrices() {
43 this->tempHtH.zeros(this->k, this->k);
44 this->tempWtW.zeros(this->k, this->k);
45 this->tempAHtij.zeros(size(this->AHtij));
46 this->tempWtAij.zeros(size(this->WtAij));
49 this->W = normalise(this->W, 2, 1);
50 this->Wt = this->W.t();
51 this->H = normalise(this->H);
52 this->Ht = this->H.t();
55 this->V.zeros(size(this->H));
57 this->U.zeros(size(this->W));
61 this->Waux.zeros(size(this->W));
62 this->Wtaux = Waux.t();
63 this->Haux.zeros(size(this->H));
64 this->Htaux = Haux.t();
72 this->L.zeros(this->k, this->k);
73 this->Lt = this->L.t();
74 this->tempWtaux.zeros(size(this->Wt));
75 this->tempHtaux.zeros(size(this->Ht));
84 alpha = trace(tempHtH) / this->k;
85 alpha = alpha > 0 ? alpha : 0.01;
86 tempHtH.diag() += alpha;
87 L = arma::conv_to<MAT>::from(arma::chol(tempHtH,
"lower"));
90 bool stop_iter =
false;
93 for (
int i = 0; i < admm_iter && !stop_iter; i++) {
97 tempAHtij = this->AHtij;
98 tempAHtij = tempAHtij + (alpha * (this->Wt + this->Ut));
102 tempWtaux = arma::solve(arma::trimatl(L), tempAHtij);
104 arma::conv_to<MAT>::from(arma::solve(arma::trimatu(Lt), tempWtaux));
110 this->Wt = this->Wt - this->Ut;
112 [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
113 this->W = this->Wt.t();
116 this->Ut = this->Ut + this->Wt - this->Wtaux;
117 this->U = this->Ut.t();
120 double r = norm(this->Wt - this->Wtaux,
"fro");
124 MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
126 this->time_stats.communication_duration(temp);
127 this->time_stats.allreduce_duration(temp);
128 globalr = sqrt(globalr);
130 double s = norm(this->W - this->Waux,
"fro");
134 MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
136 globals = sqrt(globals);
138 double normW = norm(this->W,
"fro");
142 MPI_Allreduce(&normW, &globalnormW, 1, MPI_DOUBLE, MPI_SUM,
145 globalnormW = sqrt(globalnormW);
147 double normU = norm(this->U,
"fro");
151 MPI_Allreduce(&normU, &globalnormU, 1, MPI_DOUBLE, MPI_SUM,
154 globalnormU = sqrt(globalnormU);
156 if (globalr < (tolerance * globalnormW) &&
157 globals < (tolerance * globalnormU))
164 tempWtW = arma::conv_to<MAT>::from(this->WtW);
165 beta = trace(tempWtW) / this->k;
166 beta = beta > 0 ? beta : 0.01;
167 tempWtW.diag() += beta;
169 L = arma::chol(tempWtW,
"lower");
172 bool stop_iter =
false;
175 for (
int i = 0; i < admm_iter && !stop_iter; i++) {
176 this->Haux = this->H;
179 tempWtAij = this->WtAij;
180 tempWtAij = tempWtAij + (beta * (this->Ht + this->Vt));
184 tempHtaux = arma::solve(arma::trimatl(L), tempWtAij);
186 Htaux = arma::solve(arma::trimatu(Lt), tempHtaux);
191 this->Ht = this->Ht - this->Vt;
193 [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
194 this->H = this->Ht.t();
197 this->Vt = this->Vt + this->Ht - this->Htaux;
198 this->V = this->Vt.t();
201 double r = norm(this->Ht - this->Htaux,
"fro");
205 MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
207 this->time_stats.communication_duration(temp);
208 this->time_stats.allreduce_duration(temp);
209 globalr = sqrt(globalr);
211 double s = norm(this->H - this->Haux,
"fro");
215 MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
217 globals = sqrt(globals);
219 double normH = norm(this->H,
"fro");
223 MPI_Allreduce(&normH, &globalnormH, 1, MPI_DOUBLE, MPI_SUM,
226 globalnormH = sqrt(globalnormH);
228 double normV = norm(this->V,
"fro");
232 MPI_Allreduce(&normV, &globalnormV, 1, MPI_DOUBLE, MPI_SUM,
235 globalnormV = sqrt(globalnormV);
237 if (globalr < (tolerance * globalnormH) &&
238 globals < (tolerance * globalnormV))
247 :
DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
248 communicator, numkblks) {
249 localWnorm.zeros(this->k);
250 Wnorm.zeros(this->k);
252 PRINTROOT(
"DistAOADMM() constructor successful");
267 #endif // DISTNMF_DISTAOADMM_HPP_
#define EPSILON_1EMINUS16
DistAOADMM(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
ncp_factors contains the factors of the NCP: the i-th factor is of size n_i x k, and the number of factors is ...