Forked from
ogs / ogs
15141 commits behind the upstream repository.
-
Dmitri Naumov authoredDmitri Naumov authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
NamedFunctionCaller.cpp 9.45 KiB
/**
* \copyright
* Copyright (c) 2012-2018, OpenGeoSys Community (http://www.opengeosys.org)
* Distributed under a Modified BSD License.
* See accompanying file LICENSE.txt or
* http://www.opengeosys.org/project/license
*
*/
#include "NamedFunctionCaller.h"
#include <algorithm>
#include "BaseLib/Algorithm.h"
/// To understand the working of the algorithm the following two lemmata are
/// useful:
/// If a directed graph has a topological order, then it is a Directed Acyclic
/// Graph (DAG).
/// If a graph is a DAG, then it has a node with no incoming edges.
///
/// \note Negative indices in the graph adjacency list are leaves of the graph.
///
/// \param graph represents directed graph given as an adjacency list.
/// \return true if the given directed graph is acyclic.
bool hasTopologicalOrdering(std::vector<std::vector<int>> const& graph)
{
// Number of incoming edges for each node of the graph
std::vector<unsigned> number_incoming_edges(graph.size());
// Count all incoming edges (i,j) for each node (i).
for (auto const& node_i_adjacencies : graph)
{
for (int node_j : node_i_adjacencies)
{
if (node_j >= 0) // ignore negative indices.
++number_incoming_edges[node_j];
}
}
// Working queue: a set of nodes with no incoming edges.
std::vector<std::size_t> q;
for (std::size_t node_i = 0; node_i < number_incoming_edges.size();
++node_i)
{
if (number_incoming_edges[node_i] == 0)
{
q.push_back(node_i);
}
}
auto num_dependent = number_incoming_edges.size() - q.size();
while (!q.empty())
{
auto const node_i = q.back();
q.pop_back();
// Decrement counts for all edges (i,j).
for (int node_j : graph[node_i])
{
if (node_j < 0)
continue; // ignore negative indices
if (--number_incoming_edges[node_j] == 0)
{ // Add a node without incoming edges to the queue.
q.push_back(node_j);
--num_dependent;
}
}
}
return num_dependent == 0;
}
enum class TraversePosition { StartNode, BetweenChildren, EndNode };
/*! Traverses the graph given by the adjacency list \c map_sink_source in a
* depth-first manner starting at the node \c sink_fct.
*
* At the beginning and end of each node as well as between every two child
* nodes the given \c callback is called.
*/
template <typename Callback>
void traverse(std::vector<std::vector<int>> const& map_sink_source,
int sink_fct, Callback&& callback)
{
assert(sink_fct < static_cast<int>(map_sink_source.size()));
callback(sink_fct, TraversePosition::StartNode);
if (sink_fct < 0)
return;
auto const& si_so = map_sink_source[sink_fct];
std::size_t const num_args = si_so.size();
for (std::size_t sink_arg = 0; sink_arg != num_args; ++sink_arg) {
if (sink_arg != 0)
callback(sink_fct, TraversePosition::BetweenChildren);
traverse(map_sink_source, si_so[sink_arg], callback);
}
callback(sink_fct, TraversePosition::EndNode);
}
namespace NumLib
{
NamedFunctionCaller::NamedFunctionCaller(
std::initializer_list<std::string> unbound_argument_names)
: _uninitialized(-1 - unbound_argument_names.size())
{
int idx = -1;
for (auto arg : unbound_argument_names) {
BaseLib::insertIfKeyUniqueElseError(
_map_name_idx, arg, idx,
"The name of the unbound argument is not unique.");
--idx;
}
}
void NamedFunctionCaller::addNamedFunction(NamedFunction&& fct)
{
DBUG("Adding named function `%s'", fct.getName().c_str());
BaseLib::insertIfKeyUniqueElseError(
_map_name_idx, fct.getName(), _named_functions.size(),
"The name of the function is not unique.");
_map_sink_source.emplace_back(fct.getArgumentNames().size(),
_uninitialized);
_named_functions.push_back(std::move(fct));
}
void NamedFunctionCaller::plug(const std::string& sink_fct,
const std::string& sink_arg,
const std::string& source_fct)
{
_deferred_plugs.push_back({sink_fct, sink_arg, source_fct});
}
void NamedFunctionCaller::applyPlugs()
{
while (!_deferred_plugs.empty())
{
auto const& plug = _deferred_plugs.back();
auto const& sink_fct = plug.sink_fct;
auto const& sink_arg = plug.sink_arg;
auto const& source = plug.source;
auto const source_it = _map_name_idx.find(source);
if (source_it == _map_name_idx.end())
{
OGS_FATAL("A function with the name `%s' has not been found.",
source.c_str());
}
auto const source_idx = source_it->second;
auto const sink_it = _map_name_idx.find(sink_fct);
if (sink_it == _map_name_idx.end())
{
OGS_FATAL("A function with the name `%s' has not been found.",
sink_fct.c_str());
}
auto const sink_fct_idx = sink_it->second;
auto const& sink_args =
_named_functions[sink_it->second].getArgumentNames();
auto const sink_arg_it =
std::find(sink_args.begin(), sink_args.end(), sink_arg);
if (sink_arg_it == sink_args.end())
{
OGS_FATAL(
"An argument with the name `%s' has not been found for the "
"function `%s'.",
sink_arg.c_str(), sink_fct.c_str());
}
std::size_t const sink_arg_idx =
std::distance(sink_args.begin(), sink_arg_it);
auto& sis_sos = _map_sink_source[sink_fct_idx];
if (sis_sos[sink_arg_idx] != _uninitialized)
{
OGS_FATAL("A dependency for `%s'.`%s' has already been introduced.",
sink_fct.c_str(), sink_arg.c_str());
}
sis_sos[sink_arg_idx] = source_idx;
if (!hasTopologicalOrdering(_map_sink_source))
{
OGS_FATAL(
"The call graph being plugged together must be an acyclic "
"graph. The added dependency for `%s'.`%s' introduces a cycle "
"into the graph.",
sink_fct.c_str(), sink_arg.c_str());
}
_deferred_plugs.pop_back();
}
}
double NamedFunctionCaller::call(
std::size_t function_idx,
const std::vector<double>& unbound_arguments) const
{
assert(_deferred_plugs.empty() &&
"You must call applyPlugs() before this method!");
auto const& sis_sos = _map_sink_source[function_idx];
assert(sis_sos.size() ==
_named_functions[function_idx].getArgumentNames().size());
std::vector<double> fct_args(sis_sos.size());
for (std::size_t sink=0; sink<sis_sos.size(); ++sink)
{
auto const source = sis_sos[sink];
if (source >= 0) {
fct_args[sink] = call(source, unbound_arguments);
} else {
assert(source != _uninitialized);
fct_args[sink] = unbound_arguments[-source-1];
}
}
return _named_functions[function_idx].call(fct_args);
}
std::string NamedFunctionCaller::getCallExpression(
std::string const& function_name) const
{
auto const fct_it = _map_name_idx.find(function_name);
if (fct_it == _map_name_idx.end()) {
OGS_FATAL("A function with the name `%s' has not been found.",
function_name.c_str());
}
std::string expr;
auto callback = [&](int fct_idx, TraversePosition pos)
{
switch (pos) {
case TraversePosition::StartNode:
{
if (fct_idx < 0) {
auto it = std::find_if(
_map_name_idx.begin(), _map_name_idx.end(),
[fct_idx](std::pair<std::string, int> const& e) {
return e.second == fct_idx;
});
if (it == _map_name_idx.end()) {
OGS_FATAL("The function index %i has not been found.", fct_idx);
}
expr += it->first;
} else {
expr += _named_functions[fct_idx].getName() + "(";
}
break;
}
case TraversePosition::BetweenChildren:
expr += ", ";
break;
case TraversePosition::EndNode:
expr += ")";
}
};
traverse(_map_sink_source, fct_it->second, callback);
DBUG("expression: %s", expr.c_str());
return expr;
}
SpecificFunctionCaller
NamedFunctionCaller::getSpecificFunctionCaller(const std::string &function_name)
{
auto const fct_it = _map_name_idx.find(function_name);
if (fct_it == _map_name_idx.end()) {
OGS_FATAL("A function with the name `%s' has not been found.",
function_name.c_str());
}
return SpecificFunctionCaller(fct_it->second, *this);
}
std::size_t NamedFunctionCaller::getNumberOfUnboundArguments() const
{
return -_uninitialized - 1;
}
SpecificFunctionCaller::SpecificFunctionCaller(const std::size_t function_idx,
const NamedFunctionCaller& caller)
: _function_idx(function_idx), _caller(caller)
{
}
double SpecificFunctionCaller::call(
const std::vector<double>& unbound_arguments) const
{
return _caller.call(_function_idx, unbound_arguments);
}
std::size_t SpecificFunctionCaller::getNumberOfUnboundArguments() const
{
return _caller.getNumberOfUnboundArguments();
}
} // namespace NumLib