Skip to content
Snippets Groups Projects
Forked from ogs / ogs
15141 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
NamedFunctionCaller.cpp 9.45 KiB
/**
 * \copyright
 * Copyright (c) 2012-2018, OpenGeoSys Community (http://www.opengeosys.org)
 *            Distributed under a Modified BSD License.
 *              See accompanying file LICENSE.txt or
 *              http://www.opengeosys.org/project/license
 *
 */

#include "NamedFunctionCaller.h"

#include <algorithm>

#include "BaseLib/Algorithm.h"

/// Checks whether a directed graph is acyclic, i.e. whether its nodes can be
/// brought into a topological order.
///
/// The check relies on two facts:
///  - a directed graph that has a topological order is a Directed Acyclic
///    Graph (DAG);
///  - every DAG contains at least one node without incoming edges.
///
/// Nodes without incoming edges are therefore released one after another;
/// releasing a node removes its outgoing edges, which may in turn free further
/// nodes. The graph is acyclic exactly if every node gets released this way.
///
/// \note Negative indices in the adjacency list are leaves of the graph and
///       are skipped.
///
/// \param graph directed graph given as an adjacency list; \c graph[i] holds
///              the targets \c j of all edges (i,j).
/// \return true if the given directed graph is acyclic.
bool hasTopologicalOrdering(std::vector<std::vector<int>> const& graph)
{
    std::size_t const num_nodes = graph.size();

    // in_degree[j] is the number of not yet removed edges (i,j) ending in j.
    std::vector<unsigned> in_degree(num_nodes);
    for (auto const& targets : graph)
    {
        for (int const target : targets)
        {
            if (target < 0)
            {
                continue;  // leaves do not take part in the ordering
            }
            ++in_degree[target];
        }
    }

    // Work list of nodes that currently have no incoming edges.
    std::vector<std::size_t> ready;
    for (std::size_t node = 0; node != num_nodes; ++node)
    {
        if (in_degree[node] == 0)
        {
            ready.push_back(node);
        }
    }

    // Counts every node that has ever been put on the work list.
    std::size_t num_released = ready.size();

    while (!ready.empty())
    {
        std::size_t const node = ready.back();
        ready.pop_back();

        // Remove all edges (node, target) from the graph.
        for (int const target : graph[node])
        {
            if (target < 0)
            {
                continue;
            }

            --in_degree[target];
            if (in_degree[target] == 0)
            {
                // The target lost its last incoming edge.
                ready.push_back(target);
                ++num_released;
            }
        }
    }

    // Nodes on a cycle never lose all of their incoming edges.
    return num_released == num_nodes;
}

/// Tells the callback of traverse() at which point of a node it is invoked.
enum class TraversePosition
{
    StartNode,        ///< the node has just been entered
    BetweenChildren,  ///< one child is finished, the next one follows
    EndNode           ///< all children of the node have been visited
};

/*! Depth-first traversal of the graph given by the adjacency list
 * \c map_sink_source, starting at the node \c sink_fct.
 *
 * The \c callback is invoked with the current node index and a
 * TraversePosition
 *  - once when a node is entered (StartNode),
 *  - once between every two consecutive children (BetweenChildren),
 *  - once after the last child has been processed (EndNode).
 *
 * Nodes with a negative index have no adjacency entry: for them only the
 * StartNode callback is issued.
 */
template <typename Callback>
void traverse(std::vector<std::vector<int>> const& map_sink_source,
              int sink_fct, Callback&& callback)
{
    assert(sink_fct < static_cast<int>(map_sink_source.size()));
    callback(sink_fct, TraversePosition::StartNode);

    if (sink_fct >= 0)
    {
        bool is_first_child = true;
        for (int const source : map_sink_source[sink_fct])
        {
            if (!is_first_child)
            {
                callback(sink_fct, TraversePosition::BetweenChildren);
            }
            is_first_child = false;

            traverse(map_sink_source, source, callback);
        }

        callback(sink_fct, TraversePosition::EndNode);
    }
}

namespace NumLib
{
/// Creates a caller whose unbound arguments have the given names.
///
/// Every unbound argument is registered in \c _map_name_idx under a negative
/// index: the first name gets -1, the second -2, and so on. \c _uninitialized
/// is set to the first negative value that is not used by any of these
/// indices, i.e. to -1 minus the number of unbound arguments.
///
/// \param unbound_argument_names names of the unbound arguments. The names
///        are passed to BaseLib::insertIfKeyUniqueElseError() together with an
///        error message about non-unique names.
NamedFunctionCaller::NamedFunctionCaller(
    std::initializer_list<std::string> unbound_argument_names)
    : _uninitialized(-1 - unbound_argument_names.size())
{
    int idx = -1;
    // Iterate by const reference: a by-value loop variable would copy each
    // name once more before it is copied into the map.
    for (auto const& arg : unbound_argument_names)
    {
        BaseLib::insertIfKeyUniqueElseError(
            _map_name_idx, arg, idx,
            "The name of the unbound argument is not unique.");
        --idx;
    }
}

/// Takes ownership of the named function \c fct and registers it.
///
/// The function's name is entered into \c _map_name_idx with the index the
/// function is going to have in \c _named_functions. Additionally one entry
/// per function argument is created in \c _map_sink_source, each of them
/// marked with \c _uninitialized until a source is plugged in.
void NamedFunctionCaller::addNamedFunction(NamedFunction&& fct)
{
    auto const& fct_name = fct.getName();
    DBUG("Adding named function `%s'", fct_name.c_str());

    // The function is appended below, hence its index equals the current
    // number of stored functions.
    BaseLib::insertIfKeyUniqueElseError(
        _map_name_idx, fct_name, _named_functions.size(),
        "The name of the function is not unique.");

    // No argument of the new function has a source yet.
    auto const num_arguments = fct.getArgumentNames().size();
    _map_sink_source.emplace_back(num_arguments, _uninitialized);

    // Must come last: fct is not used after being moved from.
    _named_functions.push_back(std::move(fct));
}

/// Requests that the argument \c sink_arg of the function \c sink_fct be fed
/// from \c source_fct.
///
/// The request is only recorded here. The names are resolved and checked
/// later in applyPlugs().
///
/// \param sink_fct   name of the function whose argument is to be connected.
/// \param sink_arg   name of the argument of \c sink_fct.
/// \param source_fct name that applyPlugs() looks up in \c _map_name_idx as
///                   the source of the value.
void NamedFunctionCaller::plug(const std::string& sink_fct,
                               const std::string& sink_arg,
                               const std::string& source_fct)
{
    // Deferred: nothing is validated at this point.
    _deferred_plugs.push_back({sink_fct, sink_arg, source_fct});
}

/// Resolves all connections requested via plug() and enters them into
/// \c _map_sink_source.
///
/// The deferred requests are processed from the most recently added one to
/// the oldest one. For each request a fatal error (OGS_FATAL) is raised if
///  - the source name or the sink function name is unknown,
///  - the sink name denotes an unbound argument instead of a function,
///  - the sink function has no argument with the requested name,
///  - the sink argument has already been connected, or
///  - the new connection would introduce a cycle into the call graph.
void NamedFunctionCaller::applyPlugs()
{
    while (!_deferred_plugs.empty())
    {
        auto const& deferred = _deferred_plugs.back();
        auto const& sink_fct = deferred.sink_fct;
        auto const& sink_arg = deferred.sink_arg;
        auto const& source = deferred.source;

        auto const source_it = _map_name_idx.find(source);
        if (source_it == _map_name_idx.end())
        {
            OGS_FATAL("A function with the name `%s' has not been found.",
                      source.c_str());
        }
        auto const source_idx = source_it->second;

        auto const sink_it = _map_name_idx.find(sink_fct);
        if (sink_it == _map_name_idx.end())
        {
            OGS_FATAL("A function with the name `%s' has not been found.",
                      sink_fct.c_str());
        }
        auto const sink_fct_idx = sink_it->second;

        // _map_name_idx also contains the unbound arguments, which carry
        // negative indices. They have no entry in _named_functions or
        // _map_sink_source, so using such an index below would access the
        // containers out of bounds.
        if (sink_fct_idx < 0)
        {
            OGS_FATAL(
                "`%s' is an unbound argument and not a function. Nothing can "
                "be plugged into it.",
                sink_fct.c_str());
        }

        auto const& sink_args =
            _named_functions[sink_fct_idx].getArgumentNames();
        auto const sink_arg_it =
            std::find(sink_args.begin(), sink_args.end(), sink_arg);
        if (sink_arg_it == sink_args.end())
        {
            OGS_FATAL(
                "An argument with the name `%s' has not been found for the "
                "function `%s'.",
                sink_arg.c_str(), sink_fct.c_str());
        }
        std::size_t const sink_arg_idx =
            std::distance(sink_args.begin(), sink_arg_it);

        auto& sis_sos = _map_sink_source[sink_fct_idx];
        if (sis_sos[sink_arg_idx] != _uninitialized)
        {
            OGS_FATAL("A dependency for `%s'.`%s' has already been introduced.",
                      sink_fct.c_str(), sink_arg.c_str());
        }
        sis_sos[sink_arg_idx] = source_idx;

        // Checked after every single connection so that the error message can
        // name the dependency that closed the cycle.
        if (!hasTopologicalOrdering(_map_sink_source))
        {
            OGS_FATAL(
                "The call graph being plugged together must be an acyclic "
                "graph. The added dependency for `%s'.`%s' introduces a cycle "
                "into the graph.",
                sink_fct.c_str(), sink_arg.c_str());
        }

        // The references taken above point into the removed element; they
        // must not be used past this line.
        _deferred_plugs.pop_back();
    }
}

/// Evaluates the function with index \c function_idx.
///
/// The value of each function argument is obtained from its source as
/// recorded in \c _map_sink_source: a non-negative source is another
/// function, which is evaluated recursively; a negative source \c s refers to
/// the entry <tt>-s-1</tt> of \c unbound_arguments.
///
/// \pre All deferred plugs have been applied (checked by an assertion).
///
/// \param function_idx      index of the function to be evaluated.
/// \param unbound_arguments values of the unbound arguments.
/// \return the value returned by the function for the collected arguments.
double NamedFunctionCaller::call(
    std::size_t function_idx,
    const std::vector<double>& unbound_arguments) const
{
    assert(_deferred_plugs.empty() &&
           "You must call applyPlugs() before this method!");

    auto const& sources = _map_sink_source[function_idx];
    auto const& fct = _named_functions[function_idx];
    assert(sources.size() == fct.getArgumentNames().size());

    std::vector<double> fct_args;
    fct_args.reserve(sources.size());

    for (auto const source : sources)
    {
        if (source < 0)
        {
            // Unbound argument; it must have been plugged before.
            assert(source != _uninitialized);
            fct_args.push_back(unbound_arguments[-source - 1]);
        }
        else
        {
            // The argument is the result of another function.
            fct_args.push_back(call(source, unbound_arguments));
        }
    }

    return fct.call(fct_args);
}

std::string NamedFunctionCaller::getCallExpression(
    std::string const& function_name) const
{
    auto const fct_it = _map_name_idx.find(function_name);
    if (fct_it == _map_name_idx.end()) {
        OGS_FATAL("A function with the name `%s' has not been found.",
                  function_name.c_str());
    }

    std::string expr;
    auto callback = [&](int fct_idx, TraversePosition pos)
    {
        switch (pos) {
        case TraversePosition::StartNode:
        {
            if (fct_idx < 0) {
                auto it = std::find_if(
                    _map_name_idx.begin(), _map_name_idx.end(),
                    [fct_idx](std::pair<std::string, int> const& e) {
                        return e.second == fct_idx;
                    });
                if (it == _map_name_idx.end()) {
                    OGS_FATAL("The function index %i has not been found.", fct_idx);
                }
                expr += it->first;
            } else {
                expr += _named_functions[fct_idx].getName() + "(";
            }
            break;
        }
        case TraversePosition::BetweenChildren:
            expr += ", ";
            break;
        case TraversePosition::EndNode:
            expr += ")";
        }
    };

    traverse(_map_sink_source, fct_it->second, callback);
    DBUG("expression: %s", expr.c_str());
    return expr;
}

/// Creates a SpecificFunctionCaller for the function named
/// \c function_name.
///
/// A fatal error (OGS_FATAL) is raised if the name is unknown or if it
/// denotes an unbound argument rather than a function.
///
/// \param function_name name of the function to be called later on.
/// \return a caller constructed from the function's index and this object.
SpecificFunctionCaller
NamedFunctionCaller::getSpecificFunctionCaller(const std::string &function_name)
{
    auto const fct_it = _map_name_idx.find(function_name);
    if (fct_it == _map_name_idx.end()) {
        OGS_FATAL("A function with the name `%s' has not been found.",
                  function_name.c_str());
    }

    // Unbound arguments share the map with the functions but carry negative
    // indices. Converting such an index to the std::size_t expected by
    // SpecificFunctionCaller would silently produce a huge, invalid index.
    if (fct_it->second < 0) {
        OGS_FATAL(
            "`%s' is an unbound argument and not a function. It cannot be "
            "called.",
            function_name.c_str());
    }

    return SpecificFunctionCaller(fct_it->second, *this);
}

/// Returns the number of unbound arguments of this caller.
///
/// The number is recovered from \c _uninitialized, which the constructor
/// sets to -1 minus the number of unbound arguments.
std::size_t NamedFunctionCaller::getNumberOfUnboundArguments() const
{
    auto const num_unbound = -_uninitialized - 1;
    return num_unbound;
}

/// Binds the function index \c function_idx to the given \c caller.
///
/// \param function_idx index that is passed on to NamedFunctionCaller::call().
/// \param caller       object the calls are forwarded to.
///                     NOTE(review): \c _caller is presumably a reference
///                     member, in which case \c caller must outlive this
///                     object; confirm against the header.
SpecificFunctionCaller::SpecificFunctionCaller(const std::size_t function_idx,
                                             const NamedFunctionCaller& caller)
    : _function_idx(function_idx), _caller(caller)
{
}

/// Evaluates the function this object has been created for.
///
/// \param unbound_arguments values of the unbound arguments; they are passed
///                          on unchanged to NamedFunctionCaller::call().
/// \return the value computed by NamedFunctionCaller::call().
double SpecificFunctionCaller::call(
    const std::vector<double>& unbound_arguments) const
{
    double const value = _caller.call(_function_idx, unbound_arguments);
    return value;
}

/// Returns the number of unbound arguments as reported by the underlying
/// NamedFunctionCaller.
std::size_t SpecificFunctionCaller::getNumberOfUnboundArguments() const
{
    std::size_t const num_unbound = _caller.getNumberOfUnboundArguments();
    return num_unbound;
}

} // namespace NumLib