3 #ifndef DISTNMF_DISTAOADMM_HPP_     4 #define DISTNMF_DISTAOADMM_HPP_    10 template <
class INPUTMATTYPE>
    39   double alpha, beta, tolerance;
    42   void allocateMatrices() {
    43     this->tempHtH.zeros(this->k, this->k);
    44     this->tempWtW.zeros(this->k, this->k);
    45     this->tempAHtij.zeros(size(this->AHtij));
    46     this->tempWtAij.zeros(size(this->WtAij));
    49     this->W = normalise(this->W, 2, 1);
    50     this->Wt = this->W.t();
    51     this->H = normalise(this->H);
    52     this->Ht = this->H.t();
    55     this->V.zeros(size(this->H));
    57     this->U.zeros(size(this->W));
    61     this->Waux.zeros(size(this->W));
    62     this->Wtaux = Waux.t();
    63     this->Haux.zeros(size(this->H));
    64     this->Htaux = Haux.t();
    72     this->L.zeros(this->k, this->k);
    73     this->Lt = this->L.t();
    74     this->tempWtaux.zeros(size(this->Wt));
    75     this->tempHtaux.zeros(size(this->Ht));
    84     alpha = trace(tempHtH) / this->k;
    85     alpha = alpha > 0 ? alpha : 0.01;
    86     tempHtH.diag() += alpha;
    87     L = arma::conv_to<MAT>::from(arma::chol(tempHtH, 
"lower"));
    90     bool stop_iter = 
false;
    93     for (
int i = 0; i < admm_iter && !stop_iter; i++) {
    97       tempAHtij = this->AHtij;
    98       tempAHtij = tempAHtij + (alpha * (this->Wt + this->Ut));
   102       tempWtaux = arma::solve(arma::trimatl(L), tempAHtij);
   104           arma::conv_to<MAT>::from(arma::solve(arma::trimatu(Lt), tempWtaux));
   110       this->Wt = this->Wt - this->Ut;
   112           [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
   113       this->W = this->Wt.t();
   116       this->Ut = this->Ut + this->Wt - this->Wtaux;
   117       this->U = this->Ut.t();
   120       double r = norm(this->Wt - this->Wtaux, 
"fro");
   124       MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
   126       this->time_stats.communication_duration(temp);
   127       this->time_stats.allreduce_duration(temp);
   128       globalr = sqrt(globalr);
   130       double s = norm(this->W - this->Waux, 
"fro");
   134       MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
   136       globals = sqrt(globals);
   138       double normW = norm(this->W, 
"fro");
   142       MPI_Allreduce(&normW, &globalnormW, 1, MPI_DOUBLE, MPI_SUM,
   145       globalnormW = sqrt(globalnormW);
   147       double normU = norm(this->U, 
"fro");
   151       MPI_Allreduce(&normU, &globalnormU, 1, MPI_DOUBLE, MPI_SUM,
   154       globalnormU = sqrt(globalnormU);
   156       if (globalr < (tolerance * globalnormW) &&
   157           globals < (tolerance * globalnormU))
   164     tempWtW = arma::conv_to<MAT>::from(this->WtW);
   165     beta = trace(tempWtW) / this->k;
   166     beta = beta > 0 ? beta : 0.01;
   167     tempWtW.diag() += beta;
   169     L = arma::chol(tempWtW, 
"lower");
   172     bool stop_iter = 
false;
   175     for (
int i = 0; i < admm_iter && !stop_iter; i++) {
   176       this->Haux = this->H;
   179       tempWtAij = this->WtAij;
   180       tempWtAij = tempWtAij + (beta * (this->Ht + this->Vt));
   184       tempHtaux = arma::solve(arma::trimatl(L), tempWtAij);
   186       Htaux = arma::solve(arma::trimatu(Lt), tempHtaux);
   191       this->Ht = this->Ht - this->Vt;
   193           [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
   194       this->H = this->Ht.t();
   197       this->Vt = this->Vt + this->Ht - this->Htaux;
   198       this->V = this->Vt.t();
   201       double r = norm(this->Ht - this->Htaux, 
"fro");
   205       MPI_Allreduce(&r, &globalr, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
   207       this->time_stats.communication_duration(temp);
   208       this->time_stats.allreduce_duration(temp);
   209       globalr = sqrt(globalr);
   211       double s = norm(this->H - this->Haux, 
"fro");
   215       MPI_Allreduce(&s, &globals, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
   217       globals = sqrt(globals);
   219       double normH = norm(this->H, 
"fro");
   223       MPI_Allreduce(&normH, &globalnormH, 1, MPI_DOUBLE, MPI_SUM,
   226       globalnormH = sqrt(globalnormH);
   228       double normV = norm(this->V, 
"fro");
   232       MPI_Allreduce(&normV, &globalnormV, 1, MPI_DOUBLE, MPI_SUM,
   235       globalnormV = sqrt(globalnormV);
   237       if (globalr < (tolerance * globalnormH) &&
   238           globals < (tolerance * globalnormV))
   247       : 
DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
   248                                 communicator, numkblks) {
   249     localWnorm.zeros(this->k);
   250     Wnorm.zeros(this->k);
   252     PRINTROOT(
"DistAOADMM() constructor successful");
   267 #endif  // DISTNMF_DISTAOADMM_HPP_ 
#define EPSILON_1EMINUS16
 
DistAOADMM(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)
 
ncp_factors contains the factors of the NCP. Every i-th factor is of size n_i * k; the number of factors is ...