planc
Parallel Low-rank Approximation with Non-negativity Constraints
distanlsbpp.hpp
Go to the documentation of this file.
1 /* Copyright 2016 Ramakrishnan Kannan */
2 
3 #ifndef DISTNMF_DISTANLSBPP_HPP_
4 #define DISTNMF_DISTANLSBPP_HPP_
5 
6 #include "distnmf/aunmf.hpp"
7 #include "nnls/bppnnls.hpp"
8 
// Number of right-hand-side columns handed to one iteration of the OpenMP
// loop in DistANLSBPP::updateOtherGivenOneMultipleRHS.
#ifdef BUILD_CUDA
#define ONE_THREAD_MATRIX_SIZE 1000
#include <omp.h>
#else
// Larger than the column count, so the whole input becomes a single chunk.
// The expansion names a variable `giventInput` that must be in scope where
// the macro is used. It is parenthesized so that uses such as
// `n / ONE_THREAD_MATRIX_SIZE` and `i * ONE_THREAD_MATRIX_SIZE` group
// correctly instead of expanding to `n / n + 5` and `i * n + 5`.
#define ONE_THREAD_MATRIX_SIZE (giventInput.n_cols + 5)
#endif
20 
21 namespace planc {
22 
23 template <class INPUTMATTYPE>
24 class DistANLSBPP : public DistAUNMF<INPUTMATTYPE> {
25  private:
// Row vectors that the constructor sizes to this->k and zero-fills. No other
// use of them is visible in this listing (presumably W column-norm buffers,
// local and reduced -- confirm against the base class / callers).
26  ROWVEC localWnorm;
27  ROWVEC Wnorm;
28 
// Empty: ANLS/BPP allocates no extra buffers of its own here.
29  void allocateMatrices() {}
30 
34  void updateOtherGivenOneMultipleRHS(const MAT& giventGiven,
35  const MAT& giventInput, MAT* othermat) {
36  UINT numThreads = (giventInput.n_cols / ONE_THREAD_MATRIX_SIZE) + 1;
37 #pragma omp parallel for schedule(dynamic)
38  for (UINT i = 0; i < numThreads; i++) {
39  UINT spanStart = i * ONE_THREAD_MATRIX_SIZE;
40  UINT spanEnd = (i + 1) * ONE_THREAD_MATRIX_SIZE - 1;
41  if (spanEnd > giventInput.n_cols - 1) {
42  spanEnd = giventInput.n_cols - 1;
43  }
44  // if it is exactly divisible, the last iteration is unnecessary.
45  BPPNNLS<MAT, VEC>* subProblem;
46  if (spanStart <= spanEnd) {
47  if (spanStart == spanEnd) {
48  subProblem = new BPPNNLS<MAT, VEC>(
49  giventGiven, (VEC)giventInput.col(spanStart), true);
50  } else { // if (spanStart < spanEnd)
51  subProblem = new BPPNNLS<MAT, VEC>(
52  giventGiven, (MAT)giventInput.cols(spanStart, spanEnd), true);
53  }
54 #ifdef MPI_VERBOSE
55 #pragma omp parallel
56  {
57  DISTPRINTINFO("Scheduling " << worh << " start=" << spanStart
58  << ", end=" << spanEnd
59  << ", tid=" << omp_get_thread_num());
60  }
61 #endif
62  subProblem->solveNNLS();
63 #ifdef MPI_VERBOSE
64 #pragma omp parallel
65  {
66  DISTPRINTINFO("completed " << worh << " start=" << spanStart
67  << ", end=" << spanEnd
68  << ", tid=" << omp_get_thread_num()
69  << " cpu=" << sched_getcpu());
70  }
71 #endif
72  if (spanStart == spanEnd) {
73  ROWVEC solVec = subProblem->getSolutionVector().t();
74  (*othermat).row(i) = solVec;
75  } else { // if (spanStart < spanEnd)
76  (*othermat).rows(spanStart, spanEnd) =
77  subProblem->getSolutionMatrix().t();
78  }
79  subProblem->clear();
80  delete subProblem;
81  }
82  }
83  }
84 
85  protected:
92  void updateW() {
93  updateOtherGivenOneMultipleRHS(this->HtH, this->AHtij, &this->W);
94  this->Wt = this->W.t();
95  }
102  void updateH() {
103  updateOtherGivenOneMultipleRHS(this->WtW, this->WtAij, &this->H);
104  this->Ht = this->H.t();
105  }
106 
107  public:
108  DistANLSBPP(const INPUTMATTYPE& input, const MAT& leftlowrankfactor,
109  const MAT& rightlowrankfactor,
110  const MPICommunicator& communicator, const int numkblks)
111  : DistAUNMF<INPUTMATTYPE>(input, leftlowrankfactor, rightlowrankfactor,
112  communicator, numkblks) {
113  localWnorm.zeros(this->k);
114  Wnorm.zeros(this->k);
115  PRINTROOT("DistANLSBPP() constructor successful");
116  }
117 
// NOTE(review): every release call below is commented out, so this body is
// effectively empty. The temp* members it names are not among this class's
// visible members -- confirm whether they still exist anywhere.
119  /*
120  tempHtH.clear();
121  tempWtW.clear();
122  tempAHtij.clear();
123  tempWtAij.clear();
124  */
125  }
126 }; // class DistANLSBPP
127 
128 } // namespace planc
129 
130 #endif // DISTNMF_DISTANLSBPP_HPP_
MATTYPE getSolutionMatrix()
Definition: nnls.hpp:79
#define DISTPRINTINFO(MSG)
Definition: distutils.h:37
int solveNNLS()
Definition: bppnnls.hpp:30
VECTYPE getSolutionVector()
Definition: nnls.hpp:76
#define ONE_THREAD_MATRIX_SIZE
Provides the updateW and updateH for the distributed ANLS/BPP algorithm.
Definition: distanlsbpp.hpp:18
unsigned int UINT
Definition: utils.h:68
#define MAT
Definition: utils.h:52
void clear()
Definition: nnls.hpp:82
#define ROWVEC
Definition: utils.h:54
ncp_factors contains the factors of the ncp every ith factor is of size n_i * k number of factors is ...
Definition: ncpfactors.hpp:20
#define PRINTROOT(MSG)
Definition: distutils.h:32
#define VEC
Definition: utils.h:61
DistANLSBPP(const INPUTMATTYPE &input, const MAT &leftlowrankfactor, const MAT &rightlowrankfactor, const MPICommunicator &communicator, const int numkblks)