14 #include <boost/filesystem.hpp>
15 #include <yaml-cpp/yaml.h>
27 std::time_t t = std::time(0);
28 std::tm* now = std::localtime(&t);
30 std::string datetime = std::to_string(now->tm_year + 1900) +
'-' +
31 std::to_string(now->tm_mon + 1) +
'-' + std::to_string(now->tm_mday);
32 datetime = datetime +
'-' + std::to_string(now->tm_hour) +
':' + std::to_string(now->tm_min) +
33 ':' + std::to_string(now->tm_sec);
35 this->timestamp_ = datetime;
// Load the YAML document found at filepath() and keep it in config_file.
// The opening `try {` of this guarded region is not part of this listing.
40 YAML::Node config = YAML::LoadFile(this->filepath());
// The top level of the document is expected to be a map.
// NOTE(review): assert() is compiled out when NDEBUG is defined, so this check does not
// validate user input in such builds — consider an explicit check.
42 assert(config.Type() == YAML::NodeType::Map);
43 this->config_file = config;
44 }
// Handles YAML::BadFile raised by the statements above.
catch (YAML::BadFile& exception) {
// The current working directory is only used to make the error message below more helpful.
45 const std::string current_path = boost::filesystem::current_path().native();
// NOTE(review): the extra ')' that closes this expression suggests a `throw (` on the
// elided preceding listing line — confirm against the full file.
47 std::runtime_error(
"Runtime error while trying to find config.yaml. The input file path " +
48 this->filepath() +
" could not be found from " + current_path));
53 T YamlParser::get(
const YAML::Node& node_map,
const std::initializer_list<const char*> indicies) {
54 YAML::Node node = YAML::Clone(node_map);
55 std::string faulty_variable;
56 for (
const char* index : indicies) {
57 faulty_variable = std::string(index);
59 if (!node.IsDefined()) {
60 throw std::runtime_error(
"The initialisation of " + faulty_variable +
" is wrong in config " +
65 if (node.size() > 1)
throw std::runtime_error(
"The node has more than one element.");
67 }
catch (std::exception e) {
68 throw std::runtime_error(
"There is a mismatch in the type conversion of " + faulty_variable +
69 " of " + this->filepath());
74 T
get_vector(
const YAML::Node& node_map,
const std::initializer_list<const char*> indicies) {
75 YAML::Node node = YAML::Clone(node_map);
76 std::string faulty_variable;
77 for (
const char* index : indicies) {
78 faulty_variable = std::string(index);
80 if (!node.IsDefined()) {
81 throw std::runtime_error(
"The initialisation of " + faulty_variable +
" is wrong.");
86 for (
int i = 0; i < node.size(); i++) output.push_back(node[i].as<
typename T::value_type>());
88 }
catch (std::exception e) {
89 throw std::runtime_error(
"There is a mismatch in the type conversion of " + faulty_variable);
99 this->version_ = get<std::string>(this->config_file, {
"Version"});
// General-configuration settings, read from generalConfigNode through get()/get_vector().
// NOTE(review): presumably the body of parseAndSetGeneralConfiguration; the function header and
// several interior lines are not part of this listing.
103 this->logging_ = get<std::string>(generalConfigNode, {
"logging"});
104 this->iterations_ = get<int>(generalConfigNode, {
"iterations"});
105 this->epsilonScaling_ = get<t_real>(generalConfigNode, {
"epsilonScaling"});
106 this->output_prefix_ = get<std::string>(generalConfigNode, {
"InputOutput",
"output_prefix"});
// "InputOutput/input/source" decides which of the two input blocks below is read.
108 const std::string source_str =
109 get<std::string>(generalConfigNode, {
"InputOutput",
"input",
"source"});
110 if (source_str ==
"measurements") {
// Translate the polarization name through stokes_string; std::map::at throws
// std::out_of_range for a name that is not in the map.
111 this->measurements_polarization_ =
stokes_string.at(get<std::string>(
112 generalConfigNode, {
"InputOutput",
"input",
"measurements",
"measurements_polarization"}));
// Lookup of "warm_start" (string). The beginning of this statement is elided from the listing.
114 get<std::string>(generalConfigNode, {
"InputOutput",
"input",
"measurements",
"warm_start"});
// A "simulation" block must not be present when the source is "measurements".
115 if (generalConfigNode[
"InputOutput"][
"input"][
"simulation"])
116 throw std::runtime_error(
117 "Expecting only the input measurements block in the configuration file. Please remove "
118 "simulation block!");
120 this->measurements_ = get_vector<std::vector<std::string>>(
121 generalConfigNode, {
"InputOutput",
"input",
"measurements",
"measurements_files"});
// Lookup of "w_term" (bool). The statements around it are elided from the listing.
// NOTE(review): the log text below says the flag defaults to true when it is not set —
// presumably through a try/catch around this lookup; confirm against the full file.
124 get<bool>(generalConfigNode, {
"InputOutput",
"input",
"measurements",
"w_term"});
126 PURIFY_LOW_LOG(
"W-term flag not set for input measurements; defaulting to true.");
127 this->w_term_ =
true;
// Recognised unit names are "lambda", "radians" and "pixels"; the branch bodies are elided
// from this listing. The throw below presumably sits in the final else of the chain.
130 const std::string units_measurement_str = get<std::string>(
131 generalConfigNode, {
"InputOutput",
"input",
"measurements",
"measurements_units"});
132 if (units_measurement_str ==
"lambda")
134 else if (units_measurement_str ==
"radians")
136 else if (units_measurement_str ==
"pixels")
139 throw std::runtime_error(
"Visibility units \"" + units_measurement_str +
140 "\" not recognised. Check your config file.");
141 this->measurements_sigma_ = get<t_real>(
142 generalConfigNode, {
"InputOutput",
"input",
"measurements",
"measurements_sigma"});
143 }
else if (source_str ==
"simulation") {
// Counterpart of the check in the measurements branch: a "measurements" block must not be
// present when the source is "simulation".
144 if (generalConfigNode[
"InputOutput"][
"input"][
"measurements"])
145 throw std::runtime_error(
146 "Expecting only the input simulation block in the configuration file. Please remove "
147 "measurements block!");
// Lookup of "skymodel" (string). The beginning of this statement is elided from the listing.
150 get<std::string>(generalConfigNode, {
"InputOutput",
"input",
"simulation",
"skymodel"});
151 this->signal_to_noise_ =
152 get<t_real>(generalConfigNode, {
"InputOutput",
"input",
"simulation",
"signal_to_noise"});
153 this->number_of_measurements_ = get<t_int>(
154 generalConfigNode, {
"InputOutput",
"input",
"simulation",
"number_of_measurements"});
155 this->sim_J_ = get<t_int>(generalConfigNode, {
"InputOutput",
"input",
"simulation",
"sim_J"});
156 this->w_rms_ = get<t_real>(generalConfigNode, {
"InputOutput",
"input",
"simulation",
"w_rms"});
// In the simulation case measurements_ is filled from the "coverage_files" list.
157 this->measurements_ = get_vector<std::vector<std::string>>(
158 generalConfigNode, {
"InputOutput",
"input",
"simulation",
"coverage_files"});
// Same unit names as in the measurements branch, read from "coverage_units"; branch bodies
// are elided from this listing.
160 const std::string units_measurement_str = get<std::string>(
161 generalConfigNode, {
"InputOutput",
"input",
"simulation",
"coverage_units"});
162 if (units_measurement_str ==
"lambda")
164 else if (units_measurement_str ==
"radians")
166 else if (units_measurement_str ==
"pixels")
169 throw std::runtime_error(
"Visibility units \"" + units_measurement_str +
170 "\" not recognised. Check your config file.");
// Presumably the final else of the source_str chain (its opening line is elided): any source
// other than "measurements" or "simulation" is rejected.
172 throw std::runtime_error(
"Visibility source \"" + source_str +
173 "\" not recognised. Check your config file.");
177 this->kernel_ = get<std::string>(measureOperatorsNode, {
"kernel"});
178 this->oversampling_ = get<float>(measureOperatorsNode, {
"oversampling"});
179 this->powMethod_iter_ = get<int>(measureOperatorsNode, {
"powermethod",
"iters"});
180 this->powMethod_tolerance_ = get<float>(measureOperatorsNode, {
"powermethod",
"tolerance"});
181 this->eigenvector_real_ =
182 get<std::string>(measureOperatorsNode, {
"powermethod",
"eigenvector",
"real"});
183 this->eigenvector_imag_ =
184 get<std::string>(measureOperatorsNode, {
"powermethod",
"eigenvector",
"imag"});
185 this->cellsizex_ = get<double>(measureOperatorsNode, {
"pixelSize",
"cellsizex"});
186 this->cellsizey_ = get<double>(measureOperatorsNode, {
"pixelSize",
"cellsizey"});
187 this->width_ = get<int>(measureOperatorsNode, {
"imageSize",
"width"});
188 this->height_ = get<int>(measureOperatorsNode, {
"imageSize",
"height"});
189 this->Jx_ = get<unsigned int>(measureOperatorsNode, {
"J",
"Jx"});
190 this->Jy_ = get<unsigned int>(measureOperatorsNode, {
"J",
"Jy"});
191 this->Jw_ = get<unsigned int>(measureOperatorsNode, {
"J",
"Jw"});
192 this->gpu_ = get<bool>(measureOperatorsNode, {
"gpu"});
193 this->wprojection_ = get<bool>(measureOperatorsNode, {
"wide-field",
"wprojection"});
194 this->mpi_wstacking_ = get<bool>(measureOperatorsNode, {
"wide-field",
"mpi_wstacking"});
195 this->mpi_all_to_all_ = get<bool>(measureOperatorsNode, {
"wide-field",
"mpi_all_to_all"});
196 this->kmeans_iters_ = get<t_int>(measureOperatorsNode, {
"wide-field",
"kmeans_iterations"});
197 this->conjugate_w_ = get<bool>(measureOperatorsNode, {
"wide-field",
"conjugate_w"});
201 const std::string values_str = get<std::string>(SARANode, {
"wavelet_dict"});
202 this->wavelet_basis_ = this->
getWavelets(values_str);
203 this->wavelet_levels_ = get<t_int>(SARANode, {
"wavelet_levels"});
204 this->realValueConstraint_ = get<bool>(SARANode, {
"realValueConstraint"});
205 this->positiveValueConstraint_ = get<bool>(SARANode, {
"positiveValueConstraint"});
// Algorithm-specific options. "algorithm" selects the sub-block that is read: "padmm", "fb"
// (used for both "fb" and "fb_joint_map") or "primaldual".
// NOTE(review): presumably the body of parseAndSetAlgorithmOptions; the function header and
// several interior lines (mostly assignment targets) are not part of this listing.
209 this->algorithm_ = get<std::string>(algorithmOptionsNode, {
"algorithm"});
210 if (this->algorithm_ ==
"padmm") {
211 this->epsilonConvergenceScaling_ =
212 get<t_real>(algorithmOptionsNode, {
"padmm",
"epsilonConvergenceScaling"});
// Lookup of "mpiAlgorithm"; the start of this expression is elided (the extra ')' shows it is
// nested inside another call).
214 get<std::string>(algorithmOptionsNode, {
"padmm",
"mpiAlgorithm"}));
215 this->relVarianceConvergence_ =
216 get<t_real>(algorithmOptionsNode, {
"padmm",
"relVarianceConvergence"});
217 this->update_iters_ = get<t_int>(algorithmOptionsNode, {
"padmm",
"stepsize",
"update_iters"});
218 this->update_tolerance_ =
219 get<t_real>(algorithmOptionsNode, {
"padmm",
"stepsize",
"update_tolerance"});
220 this->dualFBVarianceConvergence_ =
221 get<t_real>(algorithmOptionsNode, {
"padmm",
"dualFBVarianceConvergence"});
222 }
else if (this->algorithm_ ==
"fb" or this->algorithm_ ==
"fb_joint_map") {
// Lookup of "mpiAlgorithm"; the start of this expression is elided from the listing.
224 get<std::string>(algorithmOptionsNode, {
"fb",
"mpiAlgorithm"}));
225 this->relVarianceConvergence_ =
226 get<t_real>(algorithmOptionsNode, {
"fb",
"relVarianceConvergence"});
227 this->stepsize_ = get<t_real>(algorithmOptionsNode, {
"fb",
"stepsize"});
228 this->regularisation_parameter_ =
229 get<t_real>(algorithmOptionsNode, {
"fb",
"regularisation_parameter"});
230 this->dualFBVarianceConvergence_ =
231 get<t_real>(algorithmOptionsNode, {
"fb",
"dualFBVarianceConvergence"});
// Lookup of "nonDifferentiableFunctionType"; the start of this expression and the lines
// around it are elided from the listing.
234 get<std::string>(algorithmOptionsNode, {
"fb",
"nonDifferentiableFunctionType"}));
236 this->model_path_ = get<std::string>(algorithmOptionsNode, {
"fb",
"modelPath"});
// Lookup of "differentiableFunctionType"; the start of this expression and the lines around
// it are elided from the listing.
240 get<std::string>(algorithmOptionsNode, {
"fb",
"differentiableFunctionType"}));
242 this->CRR_function_model_path_ =
243 get<std::string>(algorithmOptionsNode, {
"fb",
"CRR_function_model_path"});
244 this->CRR_gradient_model_path_ =
245 get<std::string>(algorithmOptionsNode, {
"fb",
"CRR_gradient_model_path"});
246 this->CRR_mu_ = get<t_real>(algorithmOptionsNode, {
"fb",
"CRR_mu"});
247 this->CRR_lambda_ = get<t_real>(algorithmOptionsNode, {
"fb",
"CRR_lambda"});
// Extra settings read only for "fb_joint_map", from the "fb/joint_map_estimation" sub-block.
250 if (this->algorithm_ ==
"fb_joint_map") {
// Lookup of "iters" (t_uint); the assignment target is elided from the listing.
252 get<t_uint>(algorithmOptionsNode, {
"fb",
"joint_map_estimation",
"iters"});
253 this->jmap_relVarianceConvergence_ = get<t_real>(
254 algorithmOptionsNode, {
"fb",
"joint_map_estimation",
"relVarianceConvergence"});
255 this->jmap_objVarianceConvergence_ = get<t_real>(
256 algorithmOptionsNode, {
"fb",
"joint_map_estimation",
"objVarianceConvergence"});
// Lookup of "alpha" (t_real); the assignment target is elided from the listing.
258 get<t_real>(algorithmOptionsNode, {
"fb",
"joint_map_estimation",
"alpha"});
259 this->jmap_beta_ = get<t_real>(algorithmOptionsNode, {
"fb",
"joint_map_estimation",
"beta"});
261 }
else if (this->algorithm_ ==
"primaldual") {
262 this->epsilonConvergenceScaling_ =
263 get<t_real>(algorithmOptionsNode, {
"primaldual",
"epsilonConvergenceScaling"});
// Lookup of "mpiAlgorithm"; the start of this expression is elided from the listing.
265 get<std::string>(algorithmOptionsNode, {
"primaldual",
"mpiAlgorithm"}));
266 this->relVarianceConvergence_ =
267 get<t_real>(algorithmOptionsNode, {
"primaldual",
"relVarianceConvergence"});
268 this->update_iters_ =
269 get<t_int>(algorithmOptionsNode, {
"primaldual",
"stepsize",
"update_iters"});
270 this->update_tolerance_ =
271 get<t_real>(algorithmOptionsNode, {
"primaldual",
"stepsize",
"update_tolerance"});
272 this->precondition_iters_ =
273 get<t_int>(algorithmOptionsNode, {
"primaldual",
"precondition_iters"});
// Presumably the final else of the algorithm chain (its opening line is elided).
// NOTE(review): the message only mentions padmm although the visible branches also handle
// "fb", "fb_joint_map" and "primaldual" — consider updating the text.
275 throw std::runtime_error(
276 "Only padmm algorithm configured for now. Please fill the appropriate block in the "
277 "configuration file.");
// Expands the wavelet specification in values_str into a list of names: a token "0" becomes
// "Dirac", any other token v becomes "DB" + v, and a range introduced by '.' is expanded one
// integer at a time. Tokens are separated by ','.
// NOTE(review): the function header and several interior lines are not part of this listing.
288 std::vector<std::string> wavelets;
// Characters of the token currently being accumulated.
289 std::string value2add;
290 std::string input = values_str;
// Remove every whitespace character from the working copy before parsing.
// NOTE(review): std::isspace has undefined behaviour for a negative plain char; casting to
// unsigned char first would make this safe for non-ASCII input.
291 input.erase(std::remove_if(input.begin(), input.end(), [](
char x) { return std::isspace(x); }),
// i deliberately runs one position past the end so that the last token is flushed as well.
// NOTE(review): int is compared against the unsigned input.size() here and below.
294 for (
int i = 0; i <= input.size(); i++) {
295 if ((i == input.size()) || (input.at(i) ==
',')) {
296 wavelets.push_back((value2add ==
"0") ?
"Dirac" : (
"DB" + value2add));
298 }
else if (input.at(i) ==
'.') {
// n is the number of characters taken for the upper bound of the range, starting two
// positions after the '.': 2 when position i + 3 is past the end or holds a ',', otherwise 3.
302 const int n = ((i + 3) >= input.size()) ? 2 : ((input.at(i + 3) ==
',') ? 2 : 3);
303 const std::string final_value = input.substr(i + 2, n);
// Emit one name per integer from the accumulated token up to and including the upper bound.
305 for (
int j = std::stoi(value2add); j <= std::stoi(final_value); j++)
306 wavelets.push_back((j == 0) ?
"Dirac" : (
"DB" + std::to_string(j)));
// Any other character extends the current token (presumably the final else branch; its
// opening line is elided from the listing).
310 value2add = value2add + input.at(i);
// Derives an output location from filepath_, output_prefix_ and timestamp(), then writes the
// version and the four top-level configuration sections to "<base>_save.yaml" there.
// NOTE(review): the function header and several interior lines (including the declarations of
// out and out_path) are not part of this listing.
// Split filepath_ at its last '/'; with no '/', the directory part is empty.
319 std::size_t file_begin = this->filepath_.find_last_of(
"/");
320 if (file_begin == std::string::npos) file_begin = 0;
321 std::string file_path = filepath_.substr(0, file_begin);
322 std::string extension =
".yaml";
// NOTE(review): std::string::erase(pos) modifies filepath_ itself, so the member permanently
// loses its last extension.size() characters here; it also throws std::out_of_range when
// filepath_ is shorter than the extension. A non-mutating substr() would avoid both.
323 std::string base_file_name = this->filepath_.erase(this->filepath_.size() - extension.size());
// Drops the directory part (and the '/' after it) from the name; the left-hand side of this
// statement is presumably on the elided preceding listing line.
325 base_file_name.substr((file_path.size() ? file_path.size() + 1 : 0), base_file_name.size());
327 boost::filesystem::path
const path(this->output_prefix_);
328 out_path = output_prefix_ +
"/output_" + std::string(this->timestamp());
330 std::string out_filename = out_path +
"/" + base_file_name +
"_save.yaml";
// Emit a map with the version followed by each configuration section copied from config_file.
334 out << YAML::BeginMap;
335 out << YAML::Key <<
"Version";
336 out << this->version_;
337 out << YAML::Key <<
"GeneralConfiguration";
338 out << this->config_file[
"GeneralConfiguration"];
339 out << YAML::Key <<
"MeasureOperators";
340 out << this->config_file[
"MeasureOperators"];
341 out << YAML::Key <<
"SARA";
342 out << this->config_file[
"SARA"];
343 out << YAML::Key <<
"AlgorithmOptions";
344 out << this->config_file[
"AlgorithmOptions"];
// Write the emitted text to disk.
// NOTE(review): the visible lines do not check whether the stream opened successfully.
347 std::ofstream output_file;
348 output_file.open(out_filename);
349 output_file << out.c_str();
void parseAndSetMeasureOperators(const YAML::Node &node)
void parseAndSetSARA(const YAML::Node &node)
void parseAndSetGeneralConfiguration(const YAML::Node &node)
void parseAndSetAlgorithmOptions(const YAML::Node &node)
void setParserVariablesFromYaml()
std::vector< std::string > getWavelets(const std::string &values_str)
YamlParser(const std::string &filepath)
#define PURIFY_LOW_LOG(...)
Low priority message.
const std::map< std::string, algo_distribution > algo_distribution_string
void mkdir_recursive(const std::string &path)
Recursively create directories when they do not exist.
const std::map< std::string, stokes > stokes_string
const std::map< std::string, diff_func_type > diff_type_string
T get_vector(const YAML::Node &node_map, const std::initializer_list< const char * > indicies)
const std::map< std::string, nondiff_func_type > nondiff_type_string