3 #ifndef NMF_AOADMM_HPP_ 4 #define NMF_AOADMM_HPP_ 35 double alpha, beta, tolerance;
42 void allocateMatrices() {
43 WtW = arma::zeros<MAT>(this->k, this->k);
44 HtH = arma::zeros<MAT>(this->k, this->k);
45 WtA = arma::zeros<MAT>(this->n, this->k);
46 AH = arma::zeros<MAT>(this->m, this->k);
49 U.zeros(size(this->W));
50 V.zeros(size(this->H));
53 Htaux.zeros(size(this->H.t()));
54 H0.zeros(size(this->H));
55 tempHtaux.zeros(size(this->H.t()));
56 Wtaux.zeros(size(this->W.t()));
57 W0.zeros(size(this->W));
58 tempWtaux.zeros(size(this->W.t()));
59 L.zeros(this->k, this->k);
77 this->normalize_by_W();
81 this->normalize_by_W();
85 unsigned int currentIteration = 0;
86 this->At = this->A.t();
92 WtA = this->W.t() * this->A;
93 WtW = this->W.t() * this->W;
94 beta = trace(WtW) / this->k;
95 beta = beta > 0 ? beta : 0.01;
98 INFO <<
"starting H Prereq for " 103 L = arma::chol(WtW,
"lower");
105 bool stop_iter =
false;
108 for (
int i = 0; i < admm_iter && !stop_iter; i++) {
111 arma::solve(arma::trimatl(L), WtA + (beta * (this->H.t() + V.t())));
112 Htaux = arma::solve(arma::trimatu(L.t()), tempHtaux);
116 this->H = this->H - V;
118 [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
119 V = V + this->H - Htaux.t();
122 double r = norm(this->H - Htaux.t(),
"fro");
123 double s = norm(this->H - H0,
"fro");
124 double normH = norm(this->H,
"fro");
125 double normV = norm(V,
"fro");
127 if (r < (tolerance * normH) && s < (tolerance * normV))
131 INFO <<
"Completed H (" << currentIteration <<
"/" 133 <<
" time =" <<
toc() << std::endl;
137 AH = this->A * this->H;
138 HtH = this->H.t() * this->H;
139 alpha = trace(HtH) / this->k;
140 alpha = alpha > 0 ? alpha : 0.01;
143 INFO <<
"starting W Prereq for " 147 L = arma::chol(HtH,
"lower");
152 for (
int i = 0; i < admm_iter && !stop_iter; i++) {
154 tempWtaux = arma::solve(arma::trimatl(L),
155 AH.t() + alpha * (this->W.t() + U.t()));
156 Wtaux = arma::solve(arma::trimatu(L.t()), tempWtaux);
160 this->W = this->W - U;
162 [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
164 U = U + this->W - Wtaux.t();
167 double r = norm(this->W - Wtaux.t(),
"fro");
168 double s = norm(this->W - W0,
"fro");
169 double normW = norm(this->W,
"fro");
170 double normU = norm(U,
"fro");
172 if (r < (tolerance * normW) && s < (tolerance * normU))
176 INFO <<
"Completed W (" << currentIteration <<
"/" 178 <<
" time =" <<
toc() << std::endl;
180 INFO <<
"Completed It (" << currentIteration <<
"/" 182 <<
" time =" <<
toc() << std::endl;
184 INFO <<
"Completed it = " << currentIteration
185 <<
" AOADMMERR=" << sqrt(this->objective_err) / this->normA
195 #endif // NMF_AOADMM_HPP_ AOADMMNMF(const T &A, const MAT &llf, const MAT &rlf)
void tic()
Starts the timer. Typical use: tic(); /* some code */ double t = toc();
#define EPSILON_1EMINUS16
const unsigned int num_iterations() const
Returns the number of iterations.
AOADMMNMF(const T &A, int lowrank)
ncp_factors contains the factors of the ncp every ith factor is of size n_i * k number of factors is ...
void computeObjectiveError()