// Include guard and CONV_ERR macro, followed by cached Frobenius norms:
// normA = ||A||_F (set in otherInitializations()); normW = ||W||_F and
// normH = ||H||_F (refreshed by collectStats()).
2 #ifndef COMMON_NMF_HPP_ 3 #define COMMON_NMF_HPP_ 13 #define CONV_ERR 0.000001 37 double normA, normW, normH;
// Densities of W and H as recorded by collectStats() (stats columns 7 and 6).
38 double densityW, densityH;
// Iteration budget; otherInitializations() sets it to 20 and sizes `stats` from it.
40 unsigned int m_num_iterations;
// NOTE(review): not referenced in the visible code; presumably the path the input was read from.
41 std::string input_file_name;
49 void collectStats(
int iteration) {
50 this->normW = arma::norm(this->W,
"fro");
51 this->normH = arma::norm(this->H,
"fro");
52 UVEC nnz = find(this->W > 0);
53 this->densityW = nnz.size() / (this->m * this->k);
55 nnz = find(this->H > 0);
56 this->densityH = nnz.size() / (this->m * this->k);
57 this->stats(iteration, 4) = this->normH;
58 this->stats(iteration, 5) = this->normW;
59 this->stats(iteration, 6) = this->densityH;
60 this->stats(iteration, 7) = this->densityW;
61 this->stats(iteration, 8) = this->objective_err;
72 void applyReg(
const FVEC ®,
MAT *AtA) {
75 MAT identity = arma::eye<MAT>(this->k, this->k);
76 float lambda_l2 = reg(0);
77 (*AtA) = (*AtA) + 2 * lambda_l2 * identity;
82 MAT onematrix = arma::ones<MAT>(this->k, this->k);
83 float lambda_l1 = reg(1);
84 (*AtA) = (*AtA) + 2 * lambda_l1 * onematrix;
92 void normalize_by_W() {
93 MAT W_square = arma::pow(this->W, 2);
94 ROWVEC norm2 = arma::sqrt(arma::sum(W_square, 0));
95 for (
unsigned int i = 0; i < this->k; i++) {
97 this->W.col(i) = this->W.col(i) / norm2(i);
98 this->H.col(i) = this->H.col(i) * norm2(i);
104 void otherInitializations() {
106 this->cleared =
false;
107 this->normA = arma::norm(this->A,
"fro");
108 this->m_num_iterations = 20;
109 this->objective_err = 1000000000000;
110 this->stats.resize(m_num_iterations + 1,
NUM_STATS);
// Constructor from an input matrix and a target rank.
// W and H are filled with uniform random values after seeding Armadillo's RNG
// with fixed seeds (89 for W, 73 for H), so every run starts from the same
// factors. Both regularisation vectors start as 2-element zero vectors
// (L2 weight, L1 weight).
// NOTE(review): m, n and k are read below but their assignment is not part of
// this excerpt; presumably they are set from `input` and `rank` beforehand.
119 NMF(
const T &input,
const unsigned int rank) {
125 arma::arma_rng::set_seed(89);
126 this->W = arma::randu<MAT>(m, k);
// Reseeding makes H independent of how many draws were used for W.
128 arma::arma_rng::set_seed(73);
129 this->H = arma::randu<MAT>(n, k);
130 this->m_regW = arma::zeros<FVEC>(2);
131 this->m_regH = arma::zeros<FVEC>(2);
138 this->otherInitializations();
// Constructor from an input matrix plus caller-supplied initial factors, so
// that different algorithms can be started from identical W and H.
// The two factors must have the same number of columns (checked by assert,
// i.e. only in builds without NDEBUG); copies are kept in Winit/Hinit.
// NOTE(review): in the visible lines the row counts of W and H are not
// checked against the dimensions of A.
146 NMF(
const T &input,
const MAT &leftlowrankfactor,
147 const MAT &rightlowrankfactor) {
148 assert(leftlowrankfactor.n_cols == rightlowrankfactor.n_cols);
150 this->W = leftlowrankfactor;
151 this->H = rightlowrankfactor;
152 this->Winit = this->W;
153 this->Hinit = this->H;
157 this->m_regW = arma::zeros<FVEC>(2);
158 this->m_regH = arma::zeros<FVEC>(2);
161 this->otherInitializations();
// computeObjectiveError() body (its signature is outside this excerpt).
// Sparse build: ||A - W*H^T||_F is evaluated without forming W*H^T.
// Residuals are accumulated only over the stored nonzeros of A. The part
// contributed by A's implicit zeros is recovered as ||W*H^T||_F^2 minus the
// squared model values at the nonzeros. ||W*H^T||_F is taken from the k x k
// triangular factors of thin QR decompositions of W and H.
// NOTE(review): RwRh is presumably assigned Rw * Rh^T on a line not shown here.
193 MAT Rw(this->k, this->k);
194 MAT Rh(this->k, this->k);
195 MAT Qw(this->m, this->k);
196 MAT Qh(this->n, this->k);
197 MAT RwRh(this->k, this->k);
// Walk A in compressed-sparse-column order: entries col_ptrs[jj-1] ..
// col_ptrs[jj]-1 are the stored values of column jj-1.
200 for (
UWORD jj = 1; jj <= this->A.n_cols; jj++) {
201 UWORD startIdx = this->A.col_ptrs[jj - 1];
202 UWORD endIdx = this->A.col_ptrs[jj];
207 for (
UWORD ii = startIdx; ii < endIdx; ii++) {
208 UWORD row = this->A.row_indices[ii];
// After this loop tempsum holds the model value (W*H^T)(row, col).
// NOTE(review): col and tempsum are initialised on lines not shown here
// (presumably col = jj - 1 and tempsum = 0).
211 for (
UWORD kk = 0; kk < k; kk++) {
212 tempsum += (this->W(row, kk) * this->H(col, kk));
214 nnzwhcol += tempsum * tempsum;
215 nnzssecol += (this->A.values[ii] - tempsum)
216 * (this->A.values[ii] - tempsum);
221 qr_econ(Qw, Rw, this->W);
222 qr_econ(Qh, Rh, this->H);
224 float normWH = arma::norm(RwRh,
"fro");
230 INFO <<
"error compute time " <<
toc() << std::endl;
// NOTE(review): this branch stores sqrt(SSE), a Frobenius norm, while the
// dense branch below stores the plain sum of squared errors. objective_err
// therefore has different units depending on BUILD_SPARSE. normWH is also
// narrowed to float, and normWH^2 - nnzwh can round slightly below zero.
231 float fastErr = sqrt(nnzsse + (normWH * normWH - nnzwh));
232 this->objective_err = fastErr;
// Dense build: sum of squared errors over all of A, computed one slice of
// W*H^T at a time. PER_SPLIT is chosen so that a column slice holds about
// 3e6 entries (PER_SPLIT * A.n_rows ~= 3e6).
237 #else // ifdef BUILD_SPARSE 250 INFO <<
"Entering computeObjectiveError A=" << this->A.n_rows <<
"x" 251 << this->A.n_cols <<
" W = " << this->W.n_rows <<
"x" << this->W.n_cols
252 <<
" H=" << this->H.n_rows <<
"x" << this->H.n_cols << std::endl;
// NOTE(review): A.n_rows == 0 makes this 3e6 / 0 = inf, and converting inf
// to UWORD is undefined behaviour.
259 UWORD PER_SPLIT = std::ceil((3 * 1e6) / A.n_rows);
263 bool colSplit =
true;
266 MAT Ht = this->H.t();
267 if (this->A.n_cols > PER_SPLIT) {
271 numSplits = A.n_cols / PER_SPLIT;
273 numSplits = A.n_rows / PER_SPLIT;
276 PER_SPLIT = A.n_cols;
280 INFO <<
"PER_SPLIT = " << PER_SPLIT <<
"numSplits = " << numSplits
// One partial sum of squared errors per slice; the loop below visits slices
// 0..numSplits inclusive. objective_err is their total (no square root).
284 VEC splitErr = arma::zeros<VEC>(numSplits + 1);
286 if (colSplit && errMtx.n_rows == 0 && errMtx.n_cols == 0) {
287 errMtx = arma::zeros<MAT>(A.n_rows, PER_SPLIT);
288 A_err_sub_mtx = arma::zeros<T>(A.n_rows, PER_SPLIT);
290 errMtx = arma::zeros<MAT>(PER_SPLIT, A.n_cols);
291 A_err_sub_mtx = arma::zeros<T>(PER_SPLIT, A.n_cols);
// Slice i covers indices [i*PER_SPLIT, (i+1)*PER_SPLIT - 1], both inclusive.
293 for (
unsigned int i = 0; i <= numSplits; i++) {
294 UWORD beginIdx = i * PER_SPLIT;
295 UWORD endIdx = (i + 1) * PER_SPLIT - 1;
// NOTE(review): endIdx is inclusive, so this clamp should use `>=`.
// Example: A.n_cols = 11, PER_SPLIT = 4, i = 2. Then endIdx == A.n_cols is
// kept, and Ht.cols(beginIdx, endIdx) requests one column past the end
// (Armadillo throws unless bounds checks are disabled).
297 if (endIdx > A.n_cols) endIdx = A.n_cols - 1;
298 if (beginIdx < endIdx) {
300 INFO <<
"beginIdx=" << beginIdx <<
" endIdx= " << endIdx << std::endl;
301 INFO <<
"Ht = " << Ht.n_rows <<
"x" << Ht.n_cols << std::endl;
304 errMtx = W * Ht.cols(beginIdx, endIdx);
305 A_err_sub_mtx = A.cols(beginIdx, endIdx);
306 }
else if (beginIdx == endIdx && beginIdx < A.n_cols) {
307 errMtx = W * Ht.col(beginIdx);
308 A_err_sub_mtx = A.col(beginIdx);
// NOTE(review): this row-split clamp has the same inclusive-endIdx issue
// (`>` should be `>=`). Unlike the column case, a single-row slice
// (beginIdx == endIdx) is not handled below.
311 if (endIdx > A.n_rows) endIdx = A.n_rows - 1;
313 INFO <<
"beginIdx=" << beginIdx <<
" endIdx= " << endIdx << std::endl;
315 if (beginIdx < endIdx) {
316 A_err_sub_mtx = A.rows(beginIdx, endIdx);
317 errMtx = W.rows(beginIdx, endIdx) * Ht;
// Squared residual of the current slice, summed into splitErr(i).
// NOTE(review): if no branch above assigned the slice (beginIdx > endIdx,
// e.g. when A's dimension is a multiple of PER_SPLIT), these lines would
// reuse the previous iteration's buffers. Confirm against the brace
// structure, which is not visible in this excerpt.
320 A_err_sub_mtx -= errMtx;
321 A_err_sub_mtx %= A_err_sub_mtx;
322 splitErr(i) = arma::accu(A_err_sub_mtx);
324 double err_time =
toc();
325 INFO <<
"err compute time::" << err_time << std::endl;
326 this->objective_err = arma::sum(splitErr);
329 #endif // ifdef BUILD_SPARSE 331 MAT AtW = At * this->W;
// computeObjectiveError(At, WtW, HtH): squared error from the expansion
//   ||A - W H^T||_F^2 = ||A||_F^2 - 2 tr(H^T A^T W) + tr(W^T W H^T H),
// which avoids forming the m x n product W H^T. Assumes At = A^T,
// WtW = W^T W and HtH = H^T H, as the names suggest; confirm at call sites.
// No square root is taken. normA is cached by otherInitializations().
333 double sqnormA = this->normA * this->normA;
// NOTE(review): arma::accu(this->H % AtW) equals this trace without building
// the k x k product.
334 double TrHtAtW = arma::trace(this->H.t() * AtW);
335 double TrWtWHtH = arma::trace(WtW * HtH);
337 this->objective_err = sqnormA - (2 * TrHtAtW) + TrWtWHtH;
342 void regW(
const FVEC &iregW) { this->m_regW = iregW; }
344 void regH(
const FVEC &iregH) { this->m_regH = iregH; }
// clear() body (its signature is outside this excerpt): releases working
// memory, guarded by the `cleared` flag. otherInitializations() resets the
// flag to false, and it is set to true here once the memory has been released.
356 if (!this->cleared) {
// The dense-error work buffers are only released if they were allocated.
// NOTE(review): errMtx itself is presumably cleared on a line not shown here.
361 if (errMtx.n_rows != 0 && errMtx.n_cols != 0) {
363 A_err_sub_mtx.clear();
365 this->cleared =
true;
370 #endif // COMMON_NMF_HPP_ NMF(const T &input, const unsigned int rank)
Constructor taking an input matrix and a low rank.
void regW(const FVEC &iregW)
Sets the regularization on the left low-rank factor W.
void computeObjectiveError(const T &At, const MAT &WtW, const MAT &HtH)
void tic()
Starts the timer. Typical use: tic(); some code; double t = toc();
void clear()
Clears the memory for the input matrix A, the left low-rank factor W and the right low-rank factor H...
const unsigned int num_iterations() const
Returns the number of iterations.
MAT getRightLowRankFactor()
Returns the right low rank factor matrix H.
void num_iterations(const int it)
Sets the number of iterations for the NMF algorithms.
FVEC regH()
Returns the L2 and L1 regularization parameters of H as a vector.
NMF(const T &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor)
Constructor with initial left and right low-rank factors. Necessary when you want to compare algorithm...
virtual void computeNMF()=0
MAT getLeftLowRankFactor()
Returns the left low rank factor matrix W.
FVEC regW()
Returns the L2 and L1 regularization parameters of W as a vector.
ncp_factors contains the factors of the NCP; the i-th factor is of size n_i x k, and the number of factors is ...
void computeObjectiveError()
void regH(const FVEC &iregH)
Sets the regularization on the right low-rank factor H.