3 #ifndef NNLS_BPPNNLS_HPP_ 4 #define NNLS_BPPNNLS_HPP_ 21 template <
class MATTYPE,
class VECTYPE>
24 BPPNNLS(MATTYPE input, VECTYPE rhs,
bool prodSent =
false):
25 NNLS<MATTYPE, VECTYPE>(input, rhs, prodSent) {
27 BPPNNLS(MATTYPE input, MATTYPE RHS,
bool prodSent =
false) :
28 NNLS<MATTYPE, VECTYPE>(input, RHS, prodSent) {
33 rcIterations = solveNNLSOneRHS();
38 rcIterations = solveNNLSMultipleRHS();
49 int solveNNLSOneRHS() {
51 UINT MAX_ITERATIONS = this->q * 2;
56 for (
UINT i = 0; i < G.n_rows; i++) {
60 VECTYPE y = -this->Ctb;
62 INFO << endl <<
"C : " << this->CtC;
63 INFO <<
"b : " << this->Ctb;
64 INFO <<
"Init y:" << y;
67 UINT beta = this->q + 1;
68 unsigned int numIterations = 0;
70 bool solutionFound =
false;
72 while (numIterations < MAX_ITERATIONS) {
77 V1 = find(this->x(F) < 0);
80 STDVEC Fv1 = arma::conv_to<STDVEC>::from(F(V1));
81 STDVEC Gv2 = arma::conv_to<STDVEC>::from(G(V2));
82 std::set_union(Fv1.begin(),
86 std::inserter(Vs, Vs.begin()));
88 UVEC V = arma::conv_to<UVEC>::from(Vs);
90 INFO <<
"xf<0 : " << V1.size() << endl << V1;
91 INFO <<
"yg<0 : " << V2.size() << endl << V2;
92 INFO <<
"V :" << V.size() << endl << V;
102 if (V.size() < beta) {
111 int temp = V(V.n_rows - 1);
117 fixAllSets(F, G, V, allIdxs);
119 INFO <<
"V:" << V.size() << endl << V;
120 INFO <<
"a : " << alpha <<
", b:" << beta << endl;
121 INFO <<
"F:" << F.size() << endl << F << endl;
122 INFO <<
"G:" << G.size() << endl << G << endl;
123 INFO <<
"b4 x:" << this->x << endl;
124 INFO <<
"b4 y:" << y << endl;
132 this->x(F) = solveSymmetricLinearEquations(this->CtC(F, F),
134 VECTYPE lhs = (this->CtC * this->x);
137 INFO <<
"after x:" << this->x << endl;
138 INFO <<
"lhs y:" << lhs;
139 INFO <<
"lhs(0) " << lhs.row(0) <<
" first b" << this->Ctb.row(0)
140 <<
"diff value " << lhs.row(0) - this->Ctb.row(0) << endl;
141 INFO <<
"after y:" << y(G) << endl;
145 fixNumericalError<VECTYPE>(&this->x);
146 fixNumericalError<VECTYPE>(&y);
152 if (this->x(temp) < 0) {
153 WARN <<
"invoking lawson and hanson's fix" << endl;
160 if (numIterations >= MAX_ITERATIONS && !solutionFound) {
161 ERR <<
"max iterations passed. calling Hanson's ActivesetNNLS" 163 ERR <<
"current x initialized for Hanson's algo : " << endl
167 anls.solve(this->CtC.memptr(),
static_cast<int>(this->q),
168 this->Ctb.memptr(), this->x.memptr(), rNorm);
169 double *nnlsDual = anls.getDual();
170 INFO <<
"Activeset NNLS Dual:" << endl;
171 for (
unsigned int i = 0; i < this->q; i++) {
172 INFO << nnlsDual[i] << endl;
175 return numIterations;
182 int solveNNLSMultipleRHS() {
183 UINT currentIteration = 0;
184 UINT MAX_ITERATIONS = this->q * 2;
185 MATTYPE Y = -this->CtB;
186 UVEC Fv(this->q * this->r);
188 UVEC Gv(this->q * this->r);
189 arma::umat V(this->q, this->r);
191 IVEC alphaZeroIdxs(this->r);
192 bool solutionFound =
false;
193 for (
UINT i = 0; i < this->q * this->r; i++) {
195 allIdxs.push_back(i);
197 UVEC alpha(this->r), beta(this->r);
201 beta = beta * (this->q + 1);
203 INFO <<
"Gv :" << Gv.size() << endl << Gv;
204 INFO <<
"Rank : " << arma::rank(this->CtC) << endl;
205 INFO <<
"Condition : " << cond(this->CtC) << endl;
206 INFO <<
"a: " << endl << alpha;
207 INFO <<
"b: " << endl << beta;
208 INFO <<
"Y: " << endl << Y;
209 INFO <<
"Rank : " << arma::rank(this->CtC) << endl;
210 INFO <<
"Condition : " << cond(this->CtC) << endl;
212 while (currentIteration < MAX_ITERATIONS) {
213 UVEC V1 = find(this->X(Fv) < 0);
214 UVEC V2 = find(Y(Gv) < 0);
220 if (currentIteration == 0) {
224 INFO <<
"X(Fv)<0 : " << V1.size() << endl << V1;
225 INFO <<
"Y(Gv)<0 : " << V2.size() << endl << V2;
229 STDVEC Fvv1 = arma::conv_to<STDVEC>::from(Fv(V1));
230 STDVEC Gvv2 = arma::conv_to<STDVEC>::from(Gv(V2));
231 std::set_union(Fvv1.begin(),
235 std::inserter(Vs, Vs.begin()));
237 UVEC VIdx = arma::conv_to<UVEC>::from(Vs);
240 INFO <<
"Terminating the loop" << endl;
242 solutionFound =
true;
248 INFO <<
"V:" << V.size() << endl << V;
264 UVEC NonOptCols = find(sum(V) != 0);
266 INFO <<
"NonOptCols:" << NonOptCols.size() << NonOptCols;
268 alphaZeroIdxs.ones();
269 alphaZeroIdxs = alphaZeroIdxs * -1;
270 for (
UINT i = 0; i < NonOptCols.size(); i++) {
271 int currentIdx = NonOptCols(i);
273 if (sum(V.col(currentIdx)) < beta(currentIdx)) {
274 beta(currentIdx) = sum(V.col(currentIdx));
275 alpha(currentIdx) = 3;
277 if (alpha(currentIdx) >= 1) {
281 alpha(currentIdx) = 0;
282 UVEC temp = find(V.col(currentIdx) != 0);
284 INFO <<
"temp : " << endl << temp <<
"max :" 285 << temp.max() << endl;
287 V.col(currentIdx).zeros();
288 V(temp.max(), currentIdx) = 1;
289 alphaZeroIdxs(currentIdx) = temp.max();
293 INFO <<
"idx:" << currentIdx << endl;
294 INFO <<
"V:" << V.col(currentIdx);
295 INFO <<
"a : " << alpha(currentIdx)
296 <<
", b:" << beta(currentIdx) << endl;
301 fixAllSets(Fv, Gv, VIdx, allIdxs);
303 INFO <<
"F:" << endl << Fv << endl;
304 INFO <<
"G:" << endl << Gv << endl;
305 INFO <<
"VIdx:" << endl << VIdx << endl;
310 INFO <<
"b4 x:" << endl << this->X << endl;
311 INFO <<
"b4 y:" << endl << Y << endl;
315 arma::umat PassiveSet(this->q, this->r);
317 PassiveSet(Fv).ones();
318 UVEC FvCols = find(sum(PassiveSet) != 0);
319 this->X.cols(FvCols) = solveNormalEqComb(this->CtC,
320 this->CtB.cols(FvCols),
321 PassiveSet.cols(FvCols));
322 Y.cols(FvCols) = (this->CtC * this->X.cols(FvCols))
323 - this->CtB.cols(FvCols);
324 fixNumericalError<MATTYPE>(&this->X);
325 fixNumericalError<MATTYPE>(&Y);
329 for (
UINT i = 0; i < alphaZeroIdxs.size(); i++) {
330 if (alphaZeroIdxs(i) != -1) {
331 if (this->X(alphaZeroIdxs(i), i) < 0) {
332 WARN <<
"invoking lawson and hanson's fix for col" 333 << i <<
"it = " << currentIteration << endl;
334 Y(alphaZeroIdxs(i), i) = 0;
335 this->X(alphaZeroIdxs(i), i) = 0;
341 INFO <<
"after x:" << endl << this->X;
342 INFO <<
"after y:" << endl << Y;
346 if (currentIteration >= MAX_ITERATIONS && !solutionFound) {
347 ERR <<
"something wrong. appears to be infeasible" << endl;
349 INFO <<
"X : " << this->X.n_rows <<
"x" << this->X.n_cols
350 <<
" CtB:" << this->CtB.n_rows <<
"x" << this->CtB.n_cols
351 <<
" CtC" << this->CtC.n_rows <<
"x" << this->CtC.n_cols
352 <<
" : r=" << this->r <<
" :p=" << this->p
353 <<
" :q=" << this->q << endl;
354 std::ostringstream fileName, fileName2;
356 fileName <<
"errinputmatrix" << temp;
357 INFO <<
"input file matrix " << fileName.str() << endl;
358 this->CtC.save(fileName.str());
359 fileName2 <<
"errrhsmatrix" << temp;
360 INFO <<
"rhs file matrix " << fileName2.str() << endl;
361 this->CtB.save(fileName2.str());
365 INFO <<
"calling classical activeset" << endl;
366 for (
UINT i = 0; i < this->r; i++) {
367 UVEC V1 = find(this->X.col(i) < 0);
368 UVEC V2 = find(Y.col(i) < 0);
369 if (!V1.empty() || !V2.empty()) {
371 WARN <<
"col " << i <<
" not optimal " << endl;
372 WARN <<
"current x initialized for Hanson's algo : " 373 << endl << this->X.col(i);
377 double *currentX =
new double[this->q];
378 double *currentRHS =
new double[this->q];
379 for (
UINT j = 0; j < this->q; j++) {
380 currentX[j] = this->X(j, i);
381 currentRHS[j] = this->CtB(j, i);
383 anls.solve(this->CtC.memptr(),
static_cast<int>(this->q),
384 currentRHS, currentX, rNorm);
385 for (
UINT j = 0; j < this->q; j++) {
386 this->X(j, i) = currentX[j];
391 return currentIteration;
400 MATTYPE solveNormalEqComb(MATTYPE AtA, MATTYPE AtB, arma::umat PassSet) {
402 UVEC Pv = find(PassSet != 0);
403 UVEC anyZeros = find(PassSet == 0);
404 if (anyZeros.empty()) {
409 Z = solveSymmetricLinearEquations(AtA, AtB);
411 Z.resize(AtB.n_rows, AtB.n_cols);
413 UINT k1 = PassSet.n_cols;
418 Z(Pv) = solveSymmetricLinearEquations(AtA(Pv, Pv), AtB(Pv));
422 std::vector<UWORD> sortedIdx, beginIdx;
423 computeCorrelationScore(PassSet, sortedIdx, beginIdx);
428 for (
UINT i = 1; i < beginIdx.size(); i++) {
432 UWORD sortedBeginIdx = beginIdx[i - 1];
433 UWORD sortedEndIdx = beginIdx[i];
434 UVEC samePassiveSetCols(std::vector<UWORD>
435 (sortedIdx.begin() + sortedBeginIdx,
436 sortedIdx.begin() + sortedEndIdx));
438 UVEC currentPassiveSet = find(PassSet.col( sortedIdx[sortedBeginIdx] ) == 1);
440 INFO <<
"samePassiveSetCols:" << endl
441 << samePassiveSetCols;
442 INFO <<
"currPassiveSet : " << endl
443 << currentPassiveSet;
444 INFO <<
"AtA:" << endl
445 << AtA(currentPassiveSet, currentPassiveSet);
446 INFO <<
"AtB:" << endl
447 << AtB(currentPassiveSet, samePassiveSetCols);
449 Z(currentPassiveSet, samePassiveSetCols) =
450 solveSymmetricLinearEquations(AtA(currentPassiveSet, currentPassiveSet),
451 AtB(currentPassiveSet, samePassiveSetCols));
456 INFO <<
"Returning mat Z:" << endl << Z;
467 std::set<int> temp1, temp2;
468 STDVEC vecG = arma::conv_to<STDVEC>::from(G);
469 STDVEC vecV = arma::conv_to<STDVEC>::from(V);
470 STDVEC vecF = arma::conv_to<STDVEC>::from(F);
471 std::set_difference(vecG.begin(), vecG.end(), vecV.begin(), vecV.end(),
472 std::inserter(temp1, temp1.begin()));
473 std::set_intersection(vecV.begin(), vecV.end(), vecF.begin(), vecF.end(),
474 std::inserter(temp2, temp2.begin()));
476 std::set_union(temp1.begin(), temp1.end(),
477 temp2.begin(), temp2.end(),
478 std::inserter(Gs, Gs.begin()));
479 G = arma::conv_to<UVEC>::from(Gs);
490 std::set_difference(allIdxs.begin(), allIdxs.end(),
491 Gs.begin(), Gs.end(),
492 std::inserter(newF, newF.begin()));
498 F = arma::conv_to<UVEC>::from(newF);
501 void printSet(std::set<int> a) {
502 for (std::set<int>::iterator it = a.begin(); it != a.end(); it++) {
508 MATTYPE solveSymmetricLinearEquations(MATTYPE A, MATTYPE B) {
510 lapack_int n = A.n_cols;
511 lapack_int nrhs = B.n_cols;
512 lapack_int lda = A.n_rows;
513 lapack_int ldb = A.n_rows;
514 if (n <= 0 || nrhs <= 0) {
515 ERR <<
"something wrong in input" <<
" n=" << n
516 <<
" nrhs=" << nrhs << endl;
519 LAPACKE_dposv(LAPACK_COL_MAJOR,
'U', n, nrhs, A.memptr(), lda, B.memptr(), ldb);
520 if ((
signed int)info != 0) {
521 ERR <<
"something wrong in dpotsv call to blas info = " 522 << (
signed int)info << endl;
523 ERR <<
" A = " << A.n_rows <<
"x" << A.n_cols <<
"r(A)=" 524 << arma::rank(A) << endl << A;
525 ERR <<
" B = " << B.n_rows <<
"x" << B.n_cols << endl << B;
535 void computeCorrelationScore(arma::umat &PassSet, std::vector<UWORD> &sortedIdx,
536 std::vector<UWORD> &beginIndex) {
538 sortedIdx = sbm.sortIndex();
542 beginIndex.push_back(beginIdx);
543 for (uint i = 0; i < sortedIdx.size(); i++) {
544 if (i == sortedIdx.size() - 1 || bac(sortedIdx[i], sortedIdx[i + 1]) ==
true) {
546 beginIndex.push_back(beginIdx);
556 bool detectCycle(arma::umat &X) {
557 UVEC lastColumn = X.col(X.n_cols - 1);
558 for (uint i = 0; i < X.n_cols - 2; i++) {
559 arma::umat compVec = (X.col(i) == lastColumn);
560 if (sum(compVec.col(0)) == X.n_rows)
BPPNNLS(MATTYPE input, MATTYPE RHS, bool prodSent=false)
BPPNNLS(MATTYPE input, VECTYPE rhs, bool prodSent=false)
std::vector< int > STDVEC