1 #ifndef ALGORITHM_FACTORY_H
2 #define ALGORITHM_FACTORY_H
4 #include "purify/config.h"
15 #include <sopt/mpi/communicator.h>
18 #include <sopt/differentiable_func.h>
19 #include <sopt/imaging_forward_backward.h>
20 #include <sopt/imaging_padmm.h>
21 #include <sopt/imaging_primal_dual.h>
22 #include <sopt/joint_map.h>
23 #include <sopt/l1_non_diff_function.h>
24 #include <sopt/non_differentiable_func.h>
25 #include <sopt/real_indicator.h>
26 #include <sopt/relative_variation.h>
27 #include <sopt/utilities.h>
28 #include <sopt/wavelets.h>
29 #include <sopt/wavelets/sara.h>
31 #include <sopt/tf_non_diff_function.h>
/// NOTE(review): this is only the template header of a declaration whose remaining
/// lines are not shown in this excerpt. It presumably forward-declares algorithm_factory,
/// which is defined later with the same <class Algorithm, class... ARGS> parameters.
/// TODO confirm.
45 template <
class Algorithm,
class... ARGS>
/// @brief Builds and configures a proximal ADMM solver.
/// Takes part in overload resolution only when Algorithm is
/// sopt::algorithm::ImagingProximalADMM<t_complex> (see the enable_if below).
/// @param measurements measurement operator; its adjoint is applied to uv_data.vis.
/// @param wavelets wavelet operator; its adjoint is applied after the measurement adjoint
///        when setting the regulariser strength.
/// @param uv_data visibilities. uv_data.size() enters the l2-ball radius, and uv_data.vis
///        is passed to the solver constructor.
/// @param sigma scale factor for the l2-ball radius (presumably the noise level; confirm).
/// @param sara_size number of wavelet bases; must not exceed 1 when tight_frame is true.
/// @param max_iterations forwarded to itermax() (default 500).
/// @param real_constraint forwarded to l1_proximal_real_constraint() (default true).
/// @param positive_constraint forwarded to l1_proximal_positivity_constraint() (default true).
/// @param tight_frame forwarded to tight_frame() (default false).
/// @param relative_variation forwarded to relative_variation() (default 1e-3).
/// @param l1_proximal_tolerance forwarded to l1_proximal_tolerance() (default 1e-2).
/// @param maximum_proximal_iterations forwarded to l1_proximal_itermax() (default 50).
/// @param residual_tolerance_scaling multiplies epsilon to give residual_tolerance() (default 1).
/// @return shared pointer to the configured solver.
/// @throws std::runtime_error if sara_size > 1 while tight_frame is true, or if the
///         distributed-algorithm type is not recognised.
/// NOTE(review): this excerpt omits many source lines (the function name, several
/// parameters, and the preprocessor/switch scaffolding). The comments below describe
/// only the visible code.
48 template <
class Algorithm>
49 typename std::enable_if<
50 std::is_same<Algorithm, sopt::algorithm::ImagingProximalADMM<t_complex>>::value,
51 std::shared_ptr<Algorithm>>::type
53 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
55 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
58 const t_uint imsizex,
const t_uint sara_size,
const t_uint max_iterations = 500,
59 const bool real_constraint =
true,
const bool positive_constraint =
true,
60 const bool tight_frame =
false,
const t_real relative_variation = 1e-3,
61 const t_real l1_proximal_tolerance = 1e-2,
62 const t_uint maximum_proximal_iterations = 50,
63 const t_real residual_tolerance_scaling = 1) {
64 typedef typename Algorithm::Scalar t_scalar;
// A tight-frame l1 proximal is only consistent with a single wavelet basis.
65 if (sara_size > 1 and tight_frame)
66 throw std::runtime_error(
67 "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
68 "one wavelet basis.");
// l2-ball radius: sqrt(2N + 2*sqrt(4N)) * sigma, where N = uv_data.size() (local count).
70 auto epsilon = std::sqrt(2 * uv_data.
size() + 2 * std::sqrt(4 * uv_data.
size())) * sigma;
71 auto padmm = std::make_shared<Algorithm>(uv_data.
vis);
// Solver and l1-proximal settings. The Lagrange update scale is hard-coded to 0.9.
72 padmm->itermax(max_iterations)
73 .relative_variation(relative_variation)
74 .tight_frame(tight_frame)
75 .l1_proximal_tolerance(l1_proximal_tolerance)
77 .l1_proximal_itermax(maximum_proximal_iterations)
78 .l1_proximal_positivity_constraint(positive_constraint)
79 .l1_proximal_real_constraint(real_constraint)
80 .lagrange_update_scale(0.9)
// NOTE(review): source lines 81-89 are missing here. The regulariser strength is
// derived from wavelets^H (measurements^H vis); the rest of that expression is not shown.
90 ->regulariser_strength(
91 (wavelets->adjoint() * (measurements->adjoint() * uv_data.
vis).eval())
95 .l2ball_proximal_epsilon(epsilon)
96 .residual_tolerance(epsilon * residual_tolerance_scaling);
// MPI variant: recompute epsilon from the measurement count summed over all ranks,
// and give the l2-ball proximal the world communicator.
104 auto const comm = sopt::mpi::Communicator::World();
105 epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.
size()) +
106 2 * std::sqrt(4 * comm.all_sum_all(uv_data.
size()))) *
109 padmm->l2ball_proximal_communicator(comm);
// Alternative MPI variant: the local epsilon is rescaled by the ratio
// sqrt(sum over ranks of 4N) / (sum over ranks of sqrt(4N)).
115 auto const comm = sopt::mpi::Communicator::World();
116 epsilon = std::sqrt(2 * uv_data.
size() + 2 * std::sqrt(4 * uv_data.
size()) *
117 std::sqrt(comm.all_sum_all(4 * uv_data.
size())) /
118 comm.all_sum_all(std::sqrt(4 * uv_data.
size()))) *
// Any other distribution value is rejected.
124 throw std::runtime_error(
125 "Type of distributed proximal ADMM algorithm not recognised. You might not have compiled "
129 auto const comm = sopt::mpi::Communicator::World();
// Non-owning handle given to the convergence factories. Presumably this avoids a
// shared_ptr cycle between padmm and the callbacks it stores. TODO confirm.
130 std::weak_ptr<Algorithm>
const padmm_weak(
padmm);
// Re-apply the tolerances using the (possibly MPI-updated) epsilon.
132 padmm->residual_tolerance(epsilon * residual_tolerance_scaling).l2ball_proximal_epsilon(epsilon);
// Regulariser strength reduced across ranks (the call is truncated in this excerpt).
134 padmm->regulariser_strength(comm.all_reduce(
139 padmm->l1_proximal_adjoint_space_comm(comm);
140 padmm->residual_convergence(
141 purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv, padmm_weak));
142 padmm->objective_convergence(
143 purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv, padmm_weak));
/// @brief Builds and configures a forward-backward solver.
/// Takes part in overload resolution only when Algorithm is
/// sopt::algorithm::ImagingForwardBackward<t_complex> (see the enable_if below).
/// @param reg_parameter forwarded to regulariser_strength().
/// @param sigma multiplied by sqrt(2) and forwarded to sigma().
/// @param sara_size number of wavelet bases; must not exceed 1 when tight_frame is true.
/// @param max_iterations forwarded to itermax() (default 500).
/// @param real_constraint, positive_constraint forwarded to the L1GProximal constraints
///        (default true).
/// @param tight_frame forwarded to tight_frame() (default false).
/// @param relative_variation forwarded to relative_variation() (default 1e-3).
/// @param l1_proximal_tolerance, maximum_proximal_iterations L1GProximal sub-solver
///        settings (defaults 1e-2, 50).
/// @param model_path model file handed to TFGProximal when that g-proximal is selected
///        (default "").
/// @param f_function optional differentiable term; it is set on the solver only when
///        non-null.
/// @return shared pointer to the configured solver.
/// @throws std::runtime_error if sara_size > 1 with tight_frame, if TFGProximal is
///         requested without ONNX runtime support, if g_proximal is unrecognised, or if
///         the distributed-algorithm type is unrecognised.
/// NOTE(review): this excerpt omits many source lines (including the switch case labels).
/// The mapping of cases to g-proximal types below is inferred from the visible
/// statements only.
149 template <
class Algorithm>
150 typename std::enable_if<
151 std::is_same<Algorithm, sopt::algorithm::ImagingForwardBackward<t_complex>>::value,
152 std::shared_ptr<Algorithm>>::type
154 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
156 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
159 const t_real reg_parameter,
const t_uint imsizey,
const t_uint imsizex,
160 const t_uint sara_size,
const t_uint max_iterations = 500,
161 const bool real_constraint =
true,
const bool positive_constraint =
true,
162 const bool tight_frame =
false,
const t_real relative_variation = 1e-3,
163 const t_real l1_proximal_tolerance = 1e-2,
const t_uint maximum_proximal_iterations = 50,
164 const std::string model_path =
"",
166 std::shared_ptr<DifferentiableFunc<typename Algorithm::Scalar>> f_function =
nullptr) {
167 typedef typename Algorithm::Scalar t_scalar;
// A tight-frame l1 proximal is only consistent with a single wavelet basis.
168 if (sara_size > 1 and tight_frame)
169 throw std::runtime_error(
170 "l1 proximal not consistent: You say you are using a tight frame, but you have more than "
171 "one wavelet basis.");
172 PURIFY_INFO(
"Constructing Forward Backward algorithm");
173 auto fb = std::make_shared<Algorithm>(uv_data.
vis);
// Core solver settings. Note that sigma is scaled by sqrt(2) before being stored.
174 fb->itermax(max_iterations)
175 .regulariser_strength(reg_parameter)
176 .sigma(sigma * std::sqrt(2))
178 .relative_variation(relative_variation)
179 .tight_frame(tight_frame)
// The caller-supplied differentiable term overrides the solver's own only if provided.
182 if (f_function) fb->f_function(f_function);
// Non-differentiable term g, chosen from g_proximal below.
183 std::shared_ptr<NonDifferentiableFunc<t_scalar>> g;
185 switch (g_proximal) {
// L1 proximal, configured with the same tolerance, iteration and constraint settings.
189 auto l1_gp = std::make_shared<sopt::algorithm::L1GProximal<t_scalar>>(
false);
190 l1_gp->l1_proximal_tolerance(l1_proximal_tolerance)
192 .l1_proximal_itermax(maximum_proximal_iterations)
193 .l1_proximal_positivity_constraint(positive_constraint)
194 .l1_proximal_real_constraint(real_constraint)
// MPI: the L1 proximal communicates in both the adjoint and the direct space.
198 auto const comm = sopt::mpi::Communicator::World();
199 l1_gp->l1_proximal_adjoint_space_comm(comm);
200 l1_gp->l1_proximal_direct_space_comm(comm);
// Learned proximal loaded from model_path (requires an ONNX-runtime-enabled build).
209 g = std::make_shared<sopt::algorithm::TFGProximal<t_scalar>>(model_path);
212 throw std::runtime_error(
213 "Type TFGProximal not recognized because purify was built with onnxrt=off");
218 g = std::make_shared<sopt::algorithm::RealIndicator<t_scalar>>();
222 throw std::runtime_error(
"Type of g_proximal operator not recognised.");
// MPI: the solver's adjoint-space operations use the world communicator.
234 auto const comm = sopt::mpi::Communicator::World();
235 fb->adjoint_space_comm(comm);
241 throw std::runtime_error(
242 "Type of distributed Forward Backward algorithm not recognised. You might not have "
/// @brief Builds and configures a primal-dual solver.
/// Takes part in overload resolution only when Algorithm is
/// sopt::algorithm::ImagingPrimalDual<t_complex> (see the enable_if below).
/// @param measurements measurement operator. Its squared norm sets tau, and its adjoint
///        is applied to uv_data.vis.
/// @param wavelets wavelet operator; its adjoint is used when setting the regulariser
///        strength.
/// @param uv_data visibilities. uv_data.size() enters the l2-ball radius, and uv_data.vis
///        is passed to the solver constructor.
/// @param sigma scale factor for the l2-ball radius (presumably the noise level; confirm).
/// @param sara_size number of wavelet bases; it bounds the random wavelet-update counts
///        in the MPI variant.
/// @param max_iterations forwarded to itermax() (default 500).
/// @param real_constraint, positive_constraint forwarded to real_constraint() /
///        positivity_constraint() (default true).
/// @param relative_variation forwarded to relative_variation() (default 1e-3).
/// @param residual_tolerance_scaling multiplies epsilon to give residual_tolerance() (default 1).
/// @return shared pointer to the configured solver.
/// @throws std::runtime_error if the distributed-algorithm type is not recognised.
/// NOTE(review): this excerpt omits many source lines (the function name, some
/// parameters, and the preprocessor/switch scaffolding). The comments below describe
/// only the visible code.
250 template <
class Algorithm>
251 typename std::enable_if<
252 std::is_same<Algorithm, sopt::algorithm::ImagingPrimalDual<t_complex>>::value,
253 std::shared_ptr<Algorithm>>::type
256 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
258 std::shared_ptr<sopt::LinearTransform<Vector<typename Algorithm::Scalar>>
const>
const
261 const t_uint imsizex,
const t_uint sara_size,
const t_uint max_iterations = 500,
262 const bool real_constraint =
true,
const bool positive_constraint =
true,
263 const t_real relative_variation = 1e-3,
const t_real residual_tolerance_scaling = 1) {
264 typedef typename Algorithm::Scalar t_scalar;
// l2-ball radius: sqrt(2N + 2*sqrt(4N)) * sigma, where N = uv_data.size() (local count).
266 auto epsilon = std::sqrt(2 * uv_data.
size() + 2 * std::sqrt(4 * uv_data.
size())) * sigma;
267 auto primaldual = std::make_shared<Algorithm>(uv_data.
vis);
// Core settings. The step size tau = 0.5 / (||measurements||^2 + 1).
268 primaldual->itermax(max_iterations)
269 .relative_variation(relative_variation)
270 .real_constraint(real_constraint)
271 .positivity_constraint(positive_constraint)
274 .tau(0.5 / (measurements->sq_norm() + 1))
// NOTE(review): source lines 275-283 are missing here. The regulariser strength is
// derived from wavelets^H (measurements^H vis); the rest of that expression is not shown.
284 ->regulariser_strength(
285 (wavelets->adjoint() * (measurements->adjoint() * uv_data.
vis).eval())
289 .l2ball_proximal_epsilon(epsilon)
290 .residual_tolerance(epsilon * residual_tolerance_scaling);
// MPI variant: epsilon from the measurement count summed over all ranks.
297 auto const comm = sopt::mpi::Communicator::World();
298 epsilon = std::sqrt(2 * comm.all_sum_all(uv_data.
size()) +
299 2 * std::sqrt(4 * comm.all_sum_all(uv_data.
size()))) *
302 primaldual->l2ball_proximal_communicator(comm);
// Alternative MPI variant: the local epsilon is rescaled by the ratio
// sqrt(sum over ranks of 4N) / (sum over ranks of sqrt(4N)).
307 auto const comm = sopt::mpi::Communicator::World();
308 epsilon = std::sqrt(2 * uv_data.
size() + 2 * std::sqrt(4 * uv_data.
size()) *
309 std::sqrt(comm.all_sum_all(4 * uv_data.
size())) /
310 comm.all_sum_all(std::sqrt(4 * uv_data.
size()))) *
// Randomised-update MPI variant: same epsilon rescaling as above.
314 auto const comm = sopt::mpi::Communicator::World();
315 epsilon = std::sqrt(2 * uv_data.
size() + 2 * std::sqrt(4 * uv_data.
size()) *
316 std::sqrt(comm.all_sum_all(4 * uv_data.
size())) /
317 comm.all_sum_all(std::sqrt(4 * uv_data.
size()))) *
// Shared flags handed to the random updaters, both initialised to true.
321 std::shared_ptr<bool> random_measurement_update_ptr = std::make_shared<bool>(
true);
322 std::shared_ptr<bool> random_wavelet_update_ptr = std::make_shared<bool>(
true);
// Update about half the ranks per iteration, but at least one.
323 const t_int update_size = std::max<t_int>(std::floor(0.5 * comm.size()), 1);
325 comm, comm.size(), update_size, random_measurement_update_ptr,
"measurements");
// Wavelet updater: the total is capped by the global number of wavelet bases.
// NOTE(review): this passes random_measurement_update_ptr, while random_wavelet_update_ptr
// (declared above) is not used anywhere in the visible code. This looks like a possible
// copy-paste slip; confirm against the full source.
327 comm, std::min<t_int>(comm.size(), std::floor(comm.all_sum_all(sara_size))),
328 std::min<t_int>(update_size, std::floor(0.5 * comm.all_sum_all(sara_size))),
329 random_measurement_update_ptr,
"wavelets");
331 primaldual->random_measurement_updater(random_measurement_updater)
332 .random_wavelet_updater(random_wavelet_updater)
333 .v_all_sum_all_comm(comm)
334 .u_all_sum_all_comm(comm);
// Any other distribution value is rejected.
338 throw std::runtime_error(
339 "Type of distributed primal dual algorithm not recognised. You might not have compiled "
343 auto const comm = sopt::mpi::Communicator::World();
// Non-owning handle given to the convergence factories. Presumably this avoids a
// shared_ptr cycle between primaldual and the callbacks it stores. TODO confirm.
344 std::weak_ptr<Algorithm>
const primaldual_weak(primaldual);
// Re-apply the tolerances using the (possibly MPI-updated) epsilon.
346 primaldual->residual_tolerance(epsilon * residual_tolerance_scaling)
347 .l2ball_proximal_epsilon(epsilon);
// Regulariser strength reduced across ranks (the call is truncated in this excerpt).
349 primaldual->regulariser_strength(comm.all_reduce(
354 primaldual->residual_convergence(
355 purify::factory::l2_convergence_factory<typename Algorithm::Scalar>(rel_conv,
357 primaldual->objective_convergence(
358 purify::factory::l1_convergence_factory<typename Algorithm::Scalar>(obj_conv,
/// @brief Dispatches to the solver-specific factory chosen by `algo`, forwarding all
/// remaining arguments unchanged.
/// @return shared pointer to the constructed solver.
/// @throws std::runtime_error when the requested algorithm is not handled.
/// NOTE(review): only the padmm_factory dispatch is visible in this excerpt. The other
/// cases (source lines 369-381) are missing.
364 template <
class Algorithm,
class... ARGS>
368 return padmm_factory<Algorithm>(std::forward<ARGS>(args)...);
// NOTE(review): the message below misspells "implemented". It is left unchanged because
// it is runtime text.
382 throw std::runtime_error(
"Algorithm not implimented.");
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingProximalADMM< t_complex > >::value, std::shared_ptr< Algorithm > >::type padmm_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const t_real residual_tolerance_scaling=1)
Returns a shared pointer to the configured proximal ADMM (PADMM) algorithm object.
const std::map< std::string, algo_distribution > algo_distribution_string
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingPrimalDual< t_complex > >::value, std::shared_ptr< Algorithm > >::type primaldual_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const t_real relative_variation=1e-3, const t_real residual_tolerance_scaling=1)
Returns a shared pointer to the configured primal-dual algorithm object.
std::shared_ptr< Algorithm > algorithm_factory(const factory::algorithm algo, ARGS &&...args)
Returns the chosen algorithm, constructed from the given parameters.
std::enable_if< std::is_same< Algorithm, sopt::algorithm::ImagingForwardBackward< t_complex > >::value, std::shared_ptr< Algorithm > >::type fb_factory(const algo_distribution dist, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &measurements, std::shared_ptr< sopt::LinearTransform< Vector< typename Algorithm::Scalar >> const > const &wavelets, const utilities::vis_params &uv_data, const t_real sigma, const t_real step_size, const t_real reg_parameter, const t_uint imsizey, const t_uint imsizex, const t_uint sara_size, const t_uint max_iterations=500, const bool real_constraint=true, const bool positive_constraint=true, const bool tight_frame=false, const t_real relative_variation=1e-3, const t_real l1_proximal_tolerance=1e-2, const t_uint maximum_proximal_iterations=50, const std::string model_path="", const nondiff_func_type g_proximal=nondiff_func_type::L1Norm, std::shared_ptr< DifferentiableFunc< typename Algorithm::Scalar >> f_function=nullptr)
Returns a shared pointer to the configured forward-backward algorithm object.
std::function< bool()> random_updater(const sopt::mpi::Communicator &comm, const t_int total, const t_int update_size, const std::shared_ptr< bool > update_pointer, const std::string &update_name)
t_real step_size(T const &vis, const std::shared_ptr< sopt::LinearTransform< T > const > &measurements, const std::shared_ptr< sopt::LinearTransform< T > const > &wavelets, const t_uint sara_size)
Calculates the step size using MPI (the returned value does not include the factor of 1e-3).
void padmm(const std::string &name, const Image< t_complex > &M31, const std::string &kernel, const t_int J, const utilities::vis_params &uv_data, const t_real sigma, const std::tuple< bool, t_real > &w_term)
t_uint size() const
Returns the number of measurements.