3 #ifndef DISTNTF_DISTNTFAOADMM_HPP_ 4 #define DISTNTF_DISTNTFAOADMM_HPP_ 22 double stop_iter_time;
35 MAT update(
const int mode) {
37 MAT updated_fac(this->m_local_ncp_factors.factor(mode));
38 MAT prev_fac = updated_fac;
44 if (m_nls_sizes[mode] > 0) {
45 alpha = arma::trace(this->global_gram) / this->m_local_ncp_factors.rank();
46 alpha = (alpha > 0) ? alpha : 0.01;
47 tempgram = this->global_gram;
48 tempgram.diag() += alpha;
49 L = arma::chol(tempgram,
"lower");
53 bool stop_iter =
false;
56 for (
int i = 0; i < admm_iter && !stop_iter; i++) {
57 if (m_nls_sizes[mode] > 0) {
58 prev_fac = updated_fac;
59 m_local_ncp_aux_t.
set(mode, m_local_ncp_aux.
factor(mode).t());
61 m_temp_local_ncp_aux_t.
set(
62 mode, arma::solve(arma::trimatl(L),
63 this->ncp_local_mttkrp_t[mode] +
64 (alpha * (updated_fac.t() +
65 m_local_ncp_aux_t.
factor(mode)))));
66 m_local_ncp_aux_t.
set(mode,
67 arma::solve(arma::trimatu(Lt),
68 m_temp_local_ncp_aux_t.
factor(mode)));
71 updated_fac = m_local_ncp_aux_t.
factor(mode).t();
74 updated_fac = updated_fac - m_local_ncp_aux.
factor(mode);
76 [](MAT::elem_type &val) { val = val > 0.0 ? val : 0.0; });
79 m_local_ncp_aux.
set(mode, m_local_ncp_aux.
factor(mode) + updated_fac -
80 m_local_ncp_aux_t.
factor(mode).t());
83 double local_facnorm = 0.0;
84 double local_dualnorm = 0.0;
87 if (m_nls_sizes[mode] > 0) {
88 local_facnorm = arma::norm(updated_fac,
"fro");
89 local_dualnorm = arma::norm(m_local_ncp_aux.
factor(mode),
"fro");
90 r = norm(updated_fac.t() - m_local_ncp_aux_t.
factor(mode),
"fro");
91 s = norm(updated_fac - prev_fac,
"fro");
95 local_facnorm *= local_facnorm;
96 double global_facnorm = 0.0;
97 MPI_Allreduce(&local_facnorm, &global_facnorm, 1, MPI_DOUBLE, MPI_SUM,
99 global_facnorm = sqrt(global_facnorm);
102 local_dualnorm *= local_dualnorm;
103 double global_dualnorm = 0.0;
104 MPI_Allreduce(&local_dualnorm, &global_dualnorm, 1, MPI_DOUBLE, MPI_SUM,
106 global_dualnorm = sqrt(global_dualnorm);
110 double global_r = 0.0;
111 MPI_Allreduce(&r, &global_r, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
112 global_r = sqrt(global_r);
115 double global_s = 0.0;
116 MPI_Allreduce(&s, &global_s, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
117 global_s = sqrt(global_s);
118 if (global_r < (tolerance * global_facnorm) &&
119 global_s < (tolerance * global_dualnorm))
124 m_local_ncp_aux.distributed_normalize(mode);
126 return updated_fac.t();
131 const UVEC &i_global_dims,
const UVEC &i_local_dims,
132 const UVEC &i_nls_sizes,
const UVEC &i_nls_idxs,
134 :
DistAUNTF(i_tensor, i_k, i_algo, i_global_dims, i_local_dims,
135 i_nls_sizes, i_nls_idxs, i_mpicomm),
136 m_local_ncp_aux(i_nls_sizes, i_k, false),
137 m_local_ncp_aux_t(i_nls_sizes, i_k, true),
138 m_temp_local_ncp_aux_t(i_nls_sizes, i_k, true) {
139 m_local_ncp_aux.
zeros();
140 m_local_ncp_aux_t.
zeros();
141 m_temp_local_ncp_aux_t.
zeros();
144 tempgram.zeros(i_k, i_k);
148 stop_iter_time = 0.0;
156 <<
"::stop_iter_time::" << stop_iter_time
157 <<
"::proj_time::" << proj_time
158 <<
"::solve_time::" << solve_time
159 <<
"::norm_time::" << norm_time);
166 #endif // DISTNTF_DISTNTFAOADMM_HPP_ Data is stored such that the unfolding is column major.
#define EPSILON_1EMINUS16
void set(const int i_n, const MAT &i_factor)
Set the mode i_n with the given factor matrix.
void zeros()
this is for reinitializing zeros across different processors.
ncp_factors contains the factors of the ncp every ith factor is of size n_i * k number of factors is ...
DistNTFAOADMM(const Tensor &i_tensor, const int i_k, algotype i_algo, const UVEC &i_global_dims, const UVEC &i_local_dims, const UVEC &i_nls_sizes, const UVEC &i_nls_idxs, const NTFMPICommunicator &i_mpicomm)
MAT & factor(const int i_n) const
factor matrix of a mode i_n