#ifndef DISTNTF_DISTAUNTF_HPP_
#define DISTNTF_DISTAUNTF_HPP_

#define TENSOR_LOCAL_DIM (m_input_tensor.dimensions())
#define TENSOR_LOCAL_NUMEL (m_input_tensor.numel())

MAT *ncp_local_mttkrp_t;
virtual MAT update(int current_mode) = 0;

const Tensor &m_input_tensor;
MAT factor_local_grams;
MAT *factor_global_grams;
const unsigned int m_low_rank_k;
const unsigned int m_modes;
const UVEC m_global_dims;
const UVEC m_factor_local_dims;
unsigned int m_num_it;
unsigned int current_mode;
bool m_enable_dim_tree;
unsigned int m_current_it;
std::vector<bool> m_stale_mttkrp;
double m_global_sqnorm_A;
MAT hadamard_all_grams;
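// Computes the global Gram matrix of the current mode's factor: the local
// k x k product H^T * H is summed across processes with an all-reduce, and
// the mode's L2/L1 regularizers are then folded in via applyReg().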
void update_global_gram(const int current_mode) {
  MAT H = m_local_ncp_factors.factor(current_mode);
  factor_local_grams = H.t() * H;
  factor_global_grams[current_mode].zeros();
  MPI_Allreduce(factor_local_grams.memptr(),
                factor_global_grams[current_mode].memptr(),
                this->m_low_rank_k * this->m_low_rank_k, MPI_DOUBLE, MPI_SUM,
                /* communicator elided */);
  applyReg(this->m_regularizers(current_mode * 2),
           this->m_regularizers(current_mode * 2 + 1),
           &(factor_global_grams[current_mode]));
}
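// Applies L1 and L2 regularization for a mode to the k x k Gram matrix:
// adds 2 * lambda_l2 * I and 2 * lambda_l1 * (all-ones matrix) to AtA.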
void applyReg(float lambda_l2, float lambda_l1, MAT *AtA) {
  MAT identity = arma::eye<MAT>(this->m_low_rank_k, this->m_low_rank_k);
  (*AtA) = (*AtA) + 2 * lambda_l2 * identity;
  MAT onematrix = arma::ones<MAT>(this->m_low_rank_k, this->m_low_rank_k);
  (*AtA) = (*AtA) + 2 * lambda_l1 * onematrix;
}
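// Accumulates into global_gram the element-wise (Hadamard) product of the
// global Gram matrices of every mode except the one currently being updated.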
void gram_hadamard(unsigned int current_mode) {
  for (unsigned int i = 0; i < m_modes; i++) {
    if (i != current_mode) {
      global_gram %= factor_global_grams[i];
    }
  }
}
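// Gathers the distributed rows of the current mode's (transposed) factor so
// that every process in the mode's slice holds the full factor; counts and
// displacements are derived from itersplit()/startidx(). The gathered
// transpose is then stored back untransposed in m_gathered_ncp_factors.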
void gather_ncp_factor(const int current_mode) {
  m_gathered_ncp_factors_t.factor(current_mode).zeros();
  // ...
  MPI_Comm current_slice_comm = this->m_mpicomm.slice(current_mode);
  int slice_size;  // declaration elided in the listing
  MPI_Comm_size(current_slice_comm, &slice_size);
  int sendcnt = m_nls_sizes[current_mode] * m_low_rank_k;
  std::vector<int> recvgathercnt(slice_size, 0);
  std::vector<int> recvgatherdispl(slice_size, 0);
  int dimsize = m_factor_local_dims[current_mode];
  for (int i = 0; i < slice_size; i++) {
    recvgathercnt[i] = itersplit(dimsize, slice_size, i) * m_low_rank_k;
    recvgatherdispl[i] = startidx(dimsize, slice_size, i) * m_low_rank_k;
  }
#ifdef DISTNTF_VERBOSE
  MPI_Comm current_fiber_comm = this->m_mpicomm.fiber(current_mode);
  int fiber_size;  // declaration elided in the listing
  MPI_Comm_size(current_fiber_comm, &fiber_size);
  MPI_Comm_size(current_slice_comm, &slice_size);
  // ... (opening of the following log statement elided)
      << current_mode << "::fiber comm size::" << fiber_size
      << "::my_global_rank::" << MPI_RANK
      << "::my_slice_rank::" /* ... */ << "::my_fiber_rank::" /* ... */
      << "::sendcnt::" << sendcnt << "::gathered factor size::"
      << m_gathered_ncp_factors_t.factor(current_mode).n_elem);
#endif  // DISTNTF_VERBOSE
  MPI_Allgatherv(m_local_ncp_factors_t.factor(current_mode).memptr(), sendcnt,
                 /* send type elided */
                 m_gathered_ncp_factors_t.factor(current_mode).memptr(),
                 &recvgathercnt[0], &recvgatherdispl[0], MPI_DOUBLE,
                 /* communicator elided */);
#ifdef DISTNTF_VERBOSE
  // ... (opening of the following log statement elided)
      << m_local_ncp_factors_t.factor(current_mode) << std::endl
      << " gathered factor::" << std::endl
      << m_gathered_ncp_factors_t.factor(current_mode));
#endif  // DISTNTF_VERBOSE
  m_gathered_ncp_factors.set(
      current_mode, m_gathered_ncp_factors_t.factor(current_mode).t());
}
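// Distributed MTTKRP for one mode: the Khatri-Rao product (or a dimension
// tree, when enabled) produces the local MTTKRP, which is then
// reduce-scattered across the mode's slice so each process keeps only the
// rows it owns. Marks the mode's MTTKRP as fresh on completion.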
void distmttkrp(const int &current_mode) {
  if (!this->m_enable_dim_tree) {
    // ... (KRP computation elided; its output lands in the workspace below)
        &ncp_krp[current_mode]);
  }
  if (this->m_enable_dim_tree) {
    double multittv_time = 0;
    double mttkrp_time = 0;
    // ... (dimension-tree MTTKRP call partially elided)
        ncp_mttkrp_t[current_mode].memptr(), false,
        multittv_time, mttkrp_time);
  }
  // ...
  m_input_tensor.mttkrp(current_mode, ncp_krp[current_mode],
                        &ncp_mttkrp_t[current_mode]);
  // ...
  MPI_Comm current_slice_comm = this->m_mpicomm.slice(current_mode);
  int slice_size, slice_rank;  // declarations elided in the listing
  MPI_Comm_size(current_slice_comm, &slice_size);
  slice_rank = this->m_mpicomm.slice_rank(current_mode);
  std::vector<int> recvmttkrpsize(slice_size);
  int dimsize = m_factor_local_dims[current_mode];
  for (int i = 0; i < slice_size; i++) {
    recvmttkrpsize[i] = itersplit(dimsize, slice_size, i) * m_low_rank_k;
  }
#ifdef DISTNTF_VERBOSE
  MPI_Comm current_fiber_comm = this->m_mpicomm.fiber(current_mode);
  int fiber_size;  // declaration elided in the listing
  MPI_Comm_size(current_fiber_comm, &fiber_size);
  // ... (opening of the following log statement elided)
      << current_mode << "::slice comm size::" << slice_size
      << "::fiber comm size::" << fiber_size
      << "::my_global_rank::" << MPI_RANK
      << "::my_slice_rank::" /* ... */ << "::my_fiber_rank::" /* ... */
      << "::mttkrp_size::" << ncp_mttkrp_t[current_mode].n_elem
      << "::local_mttkrp_size::" << ncp_local_mttkrp_t[current_mode].n_elem);
#endif  // DISTNTF_VERBOSE
  ncp_local_mttkrp_t[current_mode].zeros();
  MPI_Reduce_scatter(ncp_mttkrp_t[current_mode].memptr(),
                     ncp_local_mttkrp_t[current_mode].memptr(),
                     &recvmttkrpsize[0], MPI_DOUBLE, MPI_SUM,
                     /* communicator elided */);
#ifdef DISTNTF_VERBOSE
  // ...
#endif  // DISTNTF_VERBOSE
  this->m_stale_mttkrp[current_mode] = false;
}
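// Allocates the per-mode workspaces: KRP buffers (only when the dimension
// tree is disabled), local and distributed MTTKRP outputs, and the per-mode
// global Gram matrices.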
void allocateMatrices() {
  // ...
  if (!m_enable_dim_tree) {
    ncp_krp = new MAT[m_modes];
  }
  ncp_mttkrp_t = new MAT[m_modes];
  ncp_local_mttkrp_t = new MAT[m_modes];
  factor_global_grams = new MAT[m_modes];
  factor_local_grams.zeros(this->m_low_rank_k, this->m_low_rank_k);
  global_gram.ones(this->m_low_rank_k, this->m_low_rank_k);
  UWORD current_size = 0;
  for (unsigned int i = 0; i < m_modes; i++) {
    // ... (per-mode current_size computation elided)
    if (!m_enable_dim_tree) {
      ncp_krp[i] = arma::zeros(current_size, this->m_low_rank_k);
    }
    // ... (allocation of ncp_mttkrp_t[i] elided)
    ncp_local_mttkrp_t[i] = arma::zeros(m_local_ncp_factors.factor(i).n_cols,
                                        m_local_ncp_factors.factor(i).n_rows);
    factor_global_grams[i] =
        arma::zeros(this->m_low_rank_k, this->m_low_rank_k);
  }
}
void freeMatrices() {
  for (unsigned int i = 0; i < m_modes; i++) {
    if (!m_enable_dim_tree) {
      ncp_krp[i].clear();  // assumed from the matching allocation above
    }
    ncp_mttkrp_t[i].clear();
    ncp_local_mttkrp_t[i].clear();
    factor_global_grams[i].clear();
  }
  if (!m_enable_dim_tree) {
    delete[] ncp_krp;  // assumed from the matching allocation above
  }
  delete[] ncp_mttkrp_t;
  delete[] ncp_local_mttkrp_t;
  delete[] factor_global_grams;
}
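// Reduces a locally measured time (labelled by reportstring) across all
// ranks and prints the root's value along with the min, average, and max.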
void reportTime(const double temp, const std::string &reportstring) {
  double mintemp, maxtemp, sumtemp;
  MPI_Allreduce(&temp, &maxtemp, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD);
  MPI_Allreduce(&temp, &mintemp, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD);
  MPI_Allreduce(&temp, &sumtemp, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
  // ... (opening of the following log statement elided)
      << "::k::" << this->m_low_rank_k << "::SIZE::" << MPI_SIZE
      << "::algo::" << this->m_updalgo << "::root::" << temp
      << "::min::" << mintemp << "::avg::" << (sumtemp) / (MPI_SIZE)
      << "::max::" << maxtemp);
}
void update_factor_mode(const unsigned int current_mode, const MAT &factor) {
  m_local_ncp_factors.set(current_mode, factor);
  m_local_ncp_factors.distributed_normalize(current_mode);
  MAT factor_t = m_local_ncp_factors.factor(current_mode).t();
  m_local_ncp_factors_t.set(current_mode, factor_t);
  // ...
  update_global_gram(current_mode);
  gather_ncp_factor(current_mode);
  if (this->m_enable_dim_tree) {
    // ... (dimension-tree bookkeeping elided)
  }
  for (unsigned int mode = 0; mode < this->m_modes; mode++) {
    if (mode != current_mode) this->m_stale_mttkrp[mode] = true;
  }
}
virtual void accelerate() {}
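// Aggregates and prints the per-phase timers (gram, KRP, NNLS, reduce-scatter,
// and, when error computation is enabled, the error timers) via reportTime().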
void generateReport() {
  MPI_Barrier(MPI_COMM_WORLD);
  this->reportTime(this->time_stats.duration(), "total_d");
  // ... (a few timer reports elided; the accessors on the partially elided
  //      calls below are inferred from the surviving labels)
  this->reportTime(this->time_stats.reducescatter_duration(),
                   "total_reducescatter");
  this->reportTime(this->time_stats.gram_duration(), "total_gram");
  this->reportTime(this->time_stats.krp_duration(), "total_krp");
  // ...
  this->reportTime(this->time_stats.nnls_duration(), "total_nnls");
  if (this->m_compute_error) {
    this->reportTime(this->time_stats.err_compute_duration(),
                     "total_err_compute");
    this->reportTime(this->time_stats.err_communication_duration(),
                     "total_err_communication");
  }
}
DistAUNTF(const Tensor &i_tensor, const int i_k, algotype i_algo,
          const UVEC &i_global_dims, const UVEC &i_local_dims,
          const UVEC &i_nls_sizes, const UVEC &i_nls_idxs,
          const NTFMPICommunicator &i_mpicomm)
    : m_mpicomm(i_mpicomm),
      m_nls_sizes(i_nls_sizes),
      m_nls_idxs(i_nls_idxs),
      m_local_ncp_factors(i_nls_sizes, i_k, false),
      m_local_ncp_factors_t(i_nls_sizes, i_k, true),
      m_input_tensor(i_tensor),
      m_gathered_ncp_factors(i_tensor.dimensions(), i_k, false),
      m_gathered_ncp_factors_t(i_tensor.dimensions(), i_k, true),
      // ... (a few initializers elided)
      m_modes(m_input_tensor.modes()),
      // ...
      m_global_dims(i_global_dims),
      m_factor_local_dims(i_local_dims),
      time_stats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) {
  this->m_compute_error = false;
  this->m_enable_dim_tree = false;
  this->m_accelerated = false;
  this->m_rel_error = 1.0;
  // ...
  m_local_ncp_factors.randu(149 * i_mpicomm.rank() + 103);
  m_local_ncp_factors.distributed_normalize();
  for (unsigned int i = 0; i < this->m_modes; i++) {
    MAT current_factor = arma::trans(m_local_ncp_factors.factor(i));
    m_local_ncp_factors_t.set(i, current_factor);
    this->m_stale_mttkrp.push_back(true);
  }
  m_gathered_ncp_factors.trans(m_gathered_ncp_factors_t);
  // ...
  double normA = i_tensor.norm();
  MPI_Allreduce(&normA, &this->m_global_sqnorm_A, 1, MPI_DOUBLE, MPI_SUM,
                /* communicator elided */);
  // ... (opening of the following log statement elided)
      << m_nls_sizes << "::NLS start indices::" << m_nls_idxs);
  if (this->m_enable_dim_tree) {
    // ...
  }
  // ... (end of the constructor and intervening members elided)
size_t modes() const { return this->m_modes; }
size_t rank() const { return this->m_low_rank_k; }
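// Sets whether to compute the relative error after each outer iteration.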
void compute_error(bool i_error) {
  this->m_compute_error = i_error;
  // ... (target of the following k x k allocation elided)
      arma::ones<MAT>(this->m_low_rank_k, this->m_low_rank_k);
  // ...
}
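// MTTKRP can be computed with or without dimension trees.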
void dim_tree(bool i_dim_tree) {
  this->m_enable_dim_tree = i_dim_tree;
  if (this->m_enable_dim_tree) {
    if (this->ncp_krp != NULL) {
      for (unsigned int i = 0; i < m_modes; i++) {
        // ...
      }
      // ...
    }
    // ...
  }
}
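// Enables the acceleration step; note that this also turns on error
// computation.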
void accelerated(const bool &set_acceleration) {
  this->m_accelerated = set_acceleration;
  this->m_compute_error = true;
}
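// Whether the cached MTTKRP for the given mode is out of date.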
bool is_stale_mttkrp(const int &current_mode) const {
  return this->m_stale_mttkrp[current_mode];
}
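// Fragments of reset(const NCPFactors &new_factors, bool trans = false),
// which completely resets all the factors and the state of the AUNTF; the
// branch on `trans` that selects between the two loops below is elided here.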
  for (unsigned int i = 0; i < m_modes; i++) {
    update_factor_mode(i, new_factors.factor(i));
  }
  // ...
  for (unsigned int i = 0; i < m_modes; i++) {
    update_factor_mode(i, new_factors.factor(i).t());
  }
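// Returns the factor matrix of a mode by collecting it across all the
// processors into factor_matrix (gathered along the mode's fiber, root 0).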
void factor(int mode, double *factor_matrix) {
  gather_ncp_factor(mode);
  int sendcnt = m_gathered_ncp_factors_t.factor(mode).n_elem;
  int fiber_size = this->m_mpicomm.proc_grids()[mode];
  int global_size = this->m_global_dims[mode];
  std::vector<int> recvcnts(fiber_size, 0);
  std::vector<int> displs(fiber_size, 0);
  DISTPRINTINFO("Collecting mode::" << mode << "::sendcnt::" << sendcnt
                << "::fiber_size::" << fiber_size
                << "::global_size::" << global_size);
  // ...
  for (int i = 0; i < fiber_size; i++) {
    recvcnts[i] = itersplit(global_size, fiber_size, i) * m_low_rank_k;
    displs[i] = startidx(global_size, fiber_size, i) * m_low_rank_k;
  }
  // ...
  MPI_Gatherv(m_gathered_ncp_factors_t.factor(mode).memptr(), sendcnt,
              MPI_DOUBLE, factor_matrix, &recvcnts[0], &displs[0], MPI_DOUBLE,
              0, this->m_mpicomm.fiber(mode));
}
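// The main computeNTF() loop (reassembled from the listing): grams and
// gathered factors are primed for modes 1..m_modes-1, the dimension-tree
// split mode is chosen when trees are enabled, and then each outer iteration
// cycles over the modes, rebuilding the Hadamard of Grams, solving the mode's
// update, and installing the new factor.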
void computeNTF() {
  // ...
  for (unsigned int i = 1; i < m_modes; i++) {
    update_global_gram(i);
    gather_ncp_factor(i);
  }
  if (this->m_enable_dim_tree) {
    // ... (declarations of split_mode and temp_cum_prod elided)
    size_t split_criteria = arma::prod(m_input_tensor.dimensions());
    split_criteria = std::round(std::sqrt(split_criteria));
    while (temp_cum_prod(split_mode) < split_criteria) {
      // ...
    }
    PRINTROOT("KDT Split Mode::" << split_mode
              << "::split criteria::" << split_criteria
              << "::cum prod::" << std::endl
              << temp_cum_prod << std::endl);
    // ...
    if (split_mode > 0) {
      size_t current_left = temp_cum_prod(split_mode);
      size_t good_criteria = temp_cum_prod(temp_cum_prod.n_rows - 1) /
                             temp_cum_prod(split_mode - 1);
      if (current_left > good_criteria) split_mode--;
      // ... (opening of the following log statement elided)
          << split_mode << "::split criteria::" << split_criteria
          << "::numerator::" << temp_cum_prod(temp_cum_prod.n_rows - 1)
          << "::good_criteria::" << good_criteria
          << "::current_left::" << current_left << std::endl
          << "::cum prod::" << std::endl
          << temp_cum_prod << std::endl);
    }
    // ... (dimension-tree construction elided)
  }
#ifdef DISTNTF_VERBOSE
  // ... (labels for the following dumps elided)
  this->m_local_ncp_factors.print();
  this->m_local_ncp_factors_t.print();
  this->m_gathered_ncp_factors.print();
#endif  // DISTNTF_VERBOSE
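  // Outer loop over iterations; inner loop over modes. The unnormalized
  // factor of the last mode in an iteration is kept so computeError() can
  // reuse that mode's local MTTKRP.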
  for (this->m_current_it = 0; this->m_current_it < m_num_it;
       this->m_current_it++) {
    // ...
    for (unsigned int current_mode = 0; current_mode < m_modes;
         current_mode++) {
      // ...
      gram_hadamard(current_mode);
      // ...
#ifdef DISTNTF_VERBOSE
      // ... (opening of the following log statement elided)
          << this->m_local_ncp_factors.factor(current_mode));
      this->m_gathered_ncp_factors.print();
      PRINTROOT("global_grams::" << std::endl << this->global_gram);
      this->ncp_local_mttkrp_t[current_mode].print();
#endif  // DISTNTF_VERBOSE
      // ... (the update for this mode, producing `factor`, is elided)
#ifdef DISTNTF_VERBOSE
      DISTPRINTINFO("it::" << this->m_current_it << "::mode::" << current_mode
                    /* ... */);
#endif  // DISTNTF_VERBOSE
      if (m_compute_error && current_mode == this->m_modes - 1) {
        // ... (unnorm_factor captured here; elided)
      }
      update_factor_mode(current_mode, factor.t());
    }
    if (m_compute_error) {
      double temp_err = computeError(unnorm_factor, this->m_modes - 1);
      this->m_rel_error = temp_err;
      // ...
      PRINTROOT("Iter::" << this->m_current_it << "::k::"
                << this->m_low_rank_k << "::SIZE::" << MPI_SIZE
                << "::algo::" << this->m_updalgo << "::time::" << iter_time
                << "::relative_error::" << temp_err);
    }
    if (this->m_accelerated) {
      // ...
    }
    PRINTROOT("completed it::" << this->m_current_it);
  }
}
double computeError(const MAT &unnorm_factor, int mode) {
  // ...
  hadamard_all_grams = global_gram % factor_global_grams[mode];
  VEC local_lambda = m_local_ncp_factors.lambda();
  ROWVEC temp_vec = local_lambda.t() * hadamard_all_grams;
  double sq_norm_model = arma::dot(temp_vec, local_lambda);
  // ...
  double inner_product = arma::dot(ncp_local_mttkrp_t[mode], unnorm_factor);
  // ...
  double all_inner_product;
  MPI_Allreduce(&inner_product, &all_inner_product, 1, MPI_DOUBLE, MPI_SUM,
                /* communicator elided */);
  // ...
#ifdef DISTNTF_VERBOSE
  DISTPRINTINFO("local_inner_product::" << inner_product << std::endl);
  // ... (opening of the following log statement elided)
      << this->m_global_sqnorm_A << "::model_norm_sq::" << sq_norm_model
      << "::global_inner_product::" << all_inner_product << std::endl);
#endif  // DISTNTF_VERBOSE
  double squared_err =
      this->m_global_sqnorm_A + sq_norm_model - 2 * all_inner_product;
  if (squared_err < 0) {
    PRINTROOT("computed error is negative due to round off");
    // ... (opening of the following log statement elided)
        << this->m_global_sqnorm_A
        << "::model_norm_sq::" << sq_norm_model
        << "::global_inner_product::" << all_inner_product
        << "::squared_err::" << squared_err << std::endl);
  }
  return std::sqrt(std::abs(squared_err) / this->m_global_sqnorm_A);
}
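// Overload used when the error is requested for an externally supplied set
// of (transposed) factors: the solver state is reset to those factors first
// and the unnormalized factor is rebuilt from the column norms (lambda).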
double computeError(const NCPFactors &new_factors_t, const int mode) {
  // ...
  reset(new_factors_t, true);
  // ...
  hadamard_all_grams = global_gram % factor_global_grams[mode];
  VEC local_lambda = m_local_ncp_factors.lambda();
  MAT unnorm_factor =  // lhs inferred; partially elided in the listing
      arma::diagmat(local_lambda) * new_factors_t.factor(mode);
  ROWVEC temp_vec = local_lambda.t() * hadamard_all_grams;
  double sq_norm_model = arma::dot(temp_vec, local_lambda);
  // ...
  double inner_product = arma::dot(ncp_local_mttkrp_t[mode], unnorm_factor);
  double all_inner_product;
  MPI_Allreduce(&inner_product, &all_inner_product, 1, MPI_DOUBLE, MPI_SUM,
                /* communicator elided */);
  double squared_err =
      this->m_global_sqnorm_A + sq_norm_model - 2 * all_inner_product;
  if (squared_err < 0) {
    PRINTROOT("computed error is negative due to round off");
    // ... (opening of the following log statement elided)
        << this->m_global_sqnorm_A
        << "::model_norm_sq::" << sq_norm_model
        << "::global_inner_product::" << all_inner_product
        << "::squared_err::" << squared_err << std::endl);
  }
  return std::sqrt(std::abs(squared_err) / this->m_global_sqnorm_A);
}
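// A minimal usage sketch (not part of this header): a driver would construct
// a concrete DistAUNTF-derived solver and run the main loop roughly as below.
// The `DistAUNTFWorker` name and the argument variables are hypothetical;
// only the member functions shown are taken from this class.
//
//   DistAUNTFWorker ntf(tensor, k, algo, global_dims, local_dims,
//                       nls_sizes, nls_idxs, mpicomm);
//   ntf.num_iterations(30);
//   ntf.regularizers(regs);   // two entries per mode (L2, then L1)
//   ntf.compute_error(true);
//   ntf.dim_tree(true);
//   ntf.computeNTF();
//   ntf.generateReport();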
#endif  // DISTNTF_DISTAUNTF_HPP_