11 #include "yaml-cpp/yaml.h"
12 #include "sopt/differentiable_func.h"
13 #include "sopt/non_differentiable_func.h"
14 #include "sopt/objective_functions.h"
15 #include <sopt/l1_non_diff_function.h>
16 #include <sopt/l2_differentiable_func.h>
17 #include <sopt/real_indicator.h>
19 using VectorC = sopt::Vector<std::complex<double>>;
21 int main(
int argc,
char **argv) {
23 std::cout <<
"purify_UQ should be run using three additional arguments." << std::endl;
24 std::cout <<
"purify_UQ <config_path> <reference_image_path> <surrogate_image_path>"
26 std::cout <<
"<config_path>: path to a .yaml config file specifying details of measurement "
27 "operator, wavelet operator, observations, and cost functions."
29 std::cout <<
"<reference_image_path>: path to image file (.fits) which was output from running "
30 "purify on observed data."
32 std::cout <<
"<surrogate_image_path>: path to modified image file (.fits) for feature analysis."
34 std::cout << std::endl;
36 <<
"For more information about the contents of the config file please consult the README."
42 const std::string config_path = argv[1];
43 const YAML::Node UQ_config = YAML::LoadFile(config_path);
46 const std::string ref_image_path = argv[2];
47 const std::string surrogate_image_path = argv[3];
49 const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size());
51 const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size());
53 const uint imsize_x = reference_image.cols();
54 const uint imsize_y = reference_image.rows();
56 std::unique_ptr<DifferentiableFunc<t_complex>> f;
57 std::unique_ptr<NonDifferentiableFunc<t_complex>> g;
62 double regulariser_strength = 0;
63 std::shared_ptr<sopt::LinearTransform<VectorC>> measurement_operator;
64 std::shared_ptr<const sopt::LinearTransform<VectorC>> wavelet_operator;
65 std::vector<std::tuple<std::string, t_uint>>
const sara{
66 std::make_tuple(
"Dirac", 3u), std::make_tuple(
"DB1", 3u), std::make_tuple(
"DB2", 3u),
67 std::make_tuple(
"DB3", 3u), std::make_tuple(
"DB4", 3u), std::make_tuple(
"DB5", 3u),
68 std::make_tuple(
"DB6", 3u), std::make_tuple(
"DB7", 3u), std::make_tuple(
"DB8", 3u)};
69 if (UQ_config[
"purify_config_file"]) {
72 const auto [mop_algo, wop_algo, using_mpi] =
selectOperators(purify_config);
73 auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
74 getInputData(purify_config, mop_algo, wop_algo, using_mpi);
78 w_stacks, uv_data, measurement_op_eigen_vector);
82 t_real
const flux_scale = 1.;
83 uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;
85 measurement_data = uv_data;
86 measurement_operator = transform;
92 regulariser_strength = purify_config.regularisation_parameter();
94 const std::string measurements_path = UQ_config[
"measurements_path"].as<std::string>();
99 measurement_operator = purify::factory::measurement_operator_factory<sopt::Vector<t_complex>>(
103 wavelet_operator = purify::factory::wavelet_operator_factory<Vector<t_complex>>(
104 factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x);
107 f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(
108 1, *measurement_operator);
109 g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
112 regulariser_strength = UQ_config[
"regulariser_strength"].as<
double>();
115 <<
"Regulariser strength not provided in UQ config, and no purify config was provided.\n";
116 std::cout <<
"Regulariser strength will be 0 by default." << std::endl;
123 if ((UQ_config[
"confidence_interval"]) && (UQ_config[
"alpha"])) {
124 std::cout <<
"Config should only contain one of 'confidence_interval' or 'alpha'." << std::endl;
127 if (UQ_config[
"confidence_interval"]) {
128 confidence = UQ_config[
"confidence_interval"].as<
double>();
129 alpha = 1 - confidence;
130 }
else if (UQ_config[
"alpha"]) {
131 alpha = UQ_config[
"alpha"].as<
double>();
132 confidence = 1 - alpha;
134 std::cout <<
"Config file must contain either 'confidence_interval' or 'alpha' as a parameter."
139 if ((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) {
140 std::cout <<
"Surrogate and reference images have different dimensions. Aborting." << std::endl;
144 if (((*measurement_operator) * reference_vector).size() != measurement_data.
vis.size()) {
145 std::cout <<
"Image size is not compatible with the measurement operator and data provided."
154 auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, regulariser_strength,
155 &f, &g](
const VectorC &image) {
157 const auto residuals = (*measurement_operator * image) - measurement_data.
vis;
158 auto A = f->function(image, measurement_data.
vis, (*measurement_operator));
159 auto B = g->function(image);
160 return A + regulariser_strength * B;
164 const double reference_posterior = Posterior(reference_vector);
165 const double surrogate_posterior = Posterior(surrogate_vector);
168 const double N = imsize_x * imsize_y;
169 const double tau = std::sqrt(16 * std::log(3 / alpha));
170 const double threshold = reference_posterior + tau * std::sqrt(N) + N;
172 std::cout <<
"Uncertainty Quantification." << std::endl;
173 std::cout <<
"Reference Log Posterior = " << reference_posterior << std::endl;
174 std::cout <<
"Confidence interval = " << confidence << std::endl;
175 std::cout <<
"Log Posterior threshold = " << threshold << std::endl;
176 std::cout <<
"Surrogate Log Posterior = " << surrogate_posterior << std::endl;
177 std::cout <<
"Surrogate image is "
178 << ((surrogate_posterior <= threshold) ?
"within the credible interval."
179 :
"excluded by the credible interval.")
const std::map< std::string, kernel > kernel_from_string
Image< t_complex > read2d(const std::string &fits_name)
Read image from fits file.
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from a vector of file names.
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
waveletInfo createWaveletOperator(YamlParser &params, const factory::distributed_wavelet_operator &wop_algo)
OperatorsInfo selectOperators(YamlParser &params)
inputData getInputData(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
void setupCostFunctions(const YamlParser &params, std::unique_ptr< DifferentiableFunc< t_complex >> &f, std::unique_ptr< NonDifferentiableFunc< t_complex >> &g, t_real sigma, sopt::LinearTransform< Vector< t_complex >> &Phi)
std::shared_ptr< const sopt::LinearTransform< Eigen::VectorXcd > > transform
int main(int argc, char **argv)
sopt::Vector< std::complex< double > > VectorC