PURIFY
Next-generation radio interferometric imaging
uq_main.cc
Go to the documentation of this file.
1 #include <iostream>
2 #include <stdlib.h>
3 #include <string>
4 #include <tuple>
5 #include <vector>
7 #include "purify/pfitsio.h"
8 #include "purify/setup_utils.h"
9 #include "purify/utilities.h"
10 #include "purify/yaml-parser.h"
11 #include "yaml-cpp/yaml.h"
12 #include "sopt/differentiable_func.h"
13 #include "sopt/non_differentiable_func.h"
14 #include "sopt/objective_functions.h"
15 #include <sopt/l1_non_diff_function.h>
16 #include <sopt/l2_differentiable_func.h>
17 #include <sopt/real_indicator.h>
18 
19 using VectorC = sopt::Vector<std::complex<double>>; // Complex vector type: holds the flattened images and is the vector type of the measurement/wavelet LinearTransforms.
20 
21 int main(int argc, char **argv) {
22  if (argc != 4) {
23  std::cout << "purify_UQ should be run using three additional arguments." << std::endl;
24  std::cout << "purify_UQ <config_path> <reference_image_path> <surrogate_image_path>"
25  << std::endl;
26  std::cout << "<config_path>: path to a .yaml config file specifying details of measurement "
27  "operator, wavelet operator, observations, and cost functions."
28  << std::endl;
29  std::cout << "<reference_image_path>: path to image file (.fits) which was output from running "
30  "purify on observed data."
31  << std::endl;
32  std::cout << "<surrogate_image_path>: path to modified image file (.fits) for feature analysis."
33  << std::endl;
34  std::cout << std::endl;
35  std::cout
36  << "For more information about the contents of the config file please consult the README."
37  << std::endl;
38  return 1;
39  }
40 
41  // Load and parse the config for parameters
42  const std::string config_path = argv[1];
43  const YAML::Node UQ_config = YAML::LoadFile(config_path);
44 
45  // Load the Reference and Surrogate images
46  const std::string ref_image_path = argv[2];
47  const std::string surrogate_image_path = argv[3];
48  const auto reference_image = purify::pfitsio::read2d(ref_image_path);
49  const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size());
50  const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path);
51  const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size());
52 
53  const uint imsize_x = reference_image.cols();
54  const uint imsize_y = reference_image.rows();
55 
56  std::unique_ptr<DifferentiableFunc<t_complex>> f;
57  std::unique_ptr<NonDifferentiableFunc<t_complex>> g;
58 
59  // Prepare operators and data using purify config
60  // If no purify config use basic version for now based on algo_factory test images
61  purify::utilities::vis_params measurement_data;
62  double regulariser_strength = 0;
63  std::shared_ptr<sopt::LinearTransform<VectorC>> measurement_operator;
64  std::shared_ptr<const sopt::LinearTransform<VectorC>> wavelet_operator;
65  std::vector<std::tuple<std::string, t_uint>> const sara{
66  std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
67  std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
68  std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
69  if (UQ_config["purify_config_file"]) {
70  YamlParser purify_config = YamlParser(UQ_config["purify_config_file"].as<std::string>());
71 
72  const auto [mop_algo, wop_algo, using_mpi] = selectOperators(purify_config);
73  auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
74  getInputData(purify_config, mop_algo, wop_algo, using_mpi);
75 
76  auto transform =
77  createMeasurementOperator(purify_config, mop_algo, wop_algo, using_mpi, image_index,
78  w_stacks, uv_data, measurement_op_eigen_vector);
79 
80  const waveletInfo wavelets = createWaveletOperator(purify_config, wop_algo);
81 
82  t_real const flux_scale = 1.;
83  uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;
84 
85  measurement_data = uv_data;
86  measurement_operator = transform;
87  wavelet_operator = wavelets.transform;
88 
89  // setup f and g based on config file
90  setupCostFunctions(purify_config, f, g, sigma, *measurement_operator);
91 
92  regulariser_strength = purify_config.regularisation_parameter();
93  } else {
94  const std::string measurements_path = UQ_config["measurements_path"].as<std::string>();
95  // Load the images and measurements
96  measurement_data = purify::utilities::read_visibility(measurements_path, false);
97 
98  // This is the measurement operator used in the test but this should probably be selectable
99  measurement_operator = purify::factory::measurement_operator_factory<sopt::Vector<t_complex>>(
101  imsize_x, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);
102 
103  wavelet_operator = purify::factory::wavelet_operator_factory<Vector<t_complex>>(
104  factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x);
105 
106  // default cost function
107  f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(
108  1, *measurement_operator); // what would a default sigma look like??
109  g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();
110 
111  try {
112  regulariser_strength = UQ_config["regulariser_strength"].as<double>();
113  } catch (...) {
114  std::cout
115  << "Regulariser strength not provided in UQ config, and no purify config was provided.\n";
116  std::cout << "Regulariser strength will be 0 by default." << std::endl;
117  }
118  }
119 
120  // Set up confidence and objective function params
121  double confidence;
122  double alpha;
123  if ((UQ_config["confidence_interval"]) && (UQ_config["alpha"])) {
124  std::cout << "Config should only contain one of 'confidence_interval' or 'alpha'." << std::endl;
125  return 1;
126  }
127  if (UQ_config["confidence_interval"]) {
128  confidence = UQ_config["confidence_interval"].as<double>();
129  alpha = 1 - confidence;
130  } else if (UQ_config["alpha"]) {
131  alpha = UQ_config["alpha"].as<double>();
132  confidence = 1 - alpha;
133  } else {
134  std::cout << "Config file must contain either 'confidence_interval' or 'alpha' as a parameter."
135  << std::endl;
136  return 1;
137  }
138 
139  if ((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) {
140  std::cout << "Surrogate and reference images have different dimensions. Aborting." << std::endl;
141  return 2;
142  }
143 
144  if (((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) {
145  std::cout << "Image size is not compatible with the measurement operator and data provided."
146  << std::endl;
147  return 3;
148  }
149 
150  // Calculate the posterior function for the reference image
151  // posterior = likelihood + prior
152  // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm)
153  // Prior = Sum(Psi^t * |x_i|) * regulariser_strength (L1 norm)
154  auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, regulariser_strength,
155  &f, &g](const VectorC &image) {
156  {
157  const auto residuals = (*measurement_operator * image) - measurement_data.vis;
158  auto A = f->function(image, measurement_data.vis, (*measurement_operator));
159  auto B = g->function(image);
160  return A + regulariser_strength * B;
161  }
162  };
163 
164  const double reference_posterior = Posterior(reference_vector);
165  const double surrogate_posterior = Posterior(surrogate_vector);
166 
167  // Threshold for surrogate image posterior to be within confidence limit
168  const double N = imsize_x * imsize_y;
169  const double tau = std::sqrt(16 * std::log(3 / alpha));
170  const double threshold = reference_posterior + tau * std::sqrt(N) + N;
171 
172  std::cout << "Uncertainty Quantification." << std::endl;
173  std::cout << "Reference Log Posterior = " << reference_posterior << std::endl;
174  std::cout << "Confidence interval = " << confidence << std::endl;
175  std::cout << "Log Posterior threshold = " << threshold << std::endl;
176  std::cout << "Surrogate Log Posterior = " << surrogate_posterior << std::endl;
177  std::cout << "Surrogate image is "
178  << ((surrogate_posterior <= threshold) ? "within the credible interval."
179  : "excluded by the credible interval.")
180  << std::endl;
181 
182  return 0;
183 }
const std::map< std::string, kernel > kernel_from_string
Definition: kernels.h:16
Image< t_complex > read2d(const std::string &fits_name)
Read image from FITS file.
Definition: pfitsio.cc:109
utilities::vis_params read_visibility(const std::vector< std::string > &names, const bool w_term)
Read visibility files from a vector of file names.
std::shared_ptr< sopt::LinearTransform< Vector< t_complex > > > createMeasurementOperator(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, const std::vector< t_int > &image_index, const std::vector< t_real > &w_stacks, const utilities::vis_params &uv_data, Vector< t_complex > &measurement_op_eigen_vector)
Definition: setup_utils.cc:250
waveletInfo createWaveletOperator(YamlParser &params, const factory::distributed_wavelet_operator &wop_algo)
Definition: setup_utils.cc:20
OperatorsInfo selectOperators(YamlParser &params)
Definition: setup_utils.cc:38
inputData getInputData(const YamlParser &params, const factory::distributed_measurement_operator mop_algo, const factory::distributed_wavelet_operator wop_algo, const bool using_mpi)
Definition: setup_utils.cc:66
void setupCostFunctions(const YamlParser &params, std::unique_ptr< DifferentiableFunc< t_complex >> &f, std::unique_ptr< NonDifferentiableFunc< t_complex >> &g, t_real sigma, sopt::LinearTransform< Vector< t_complex >> &Phi)
Definition: setup_utils.cc:312
Vector< t_complex > vis
Definition: uvw_utilities.h:22
std::shared_ptr< const sopt::LinearTransform< Eigen::VectorXcd > > transform
Definition: setup_utils.h:17
int main(int argc, char **argv)
Definition: uq_main.cc:21
sopt::Vector< std::complex< double > > VectorC
Definition: uq_main.cc:19