Verified Commit c3e1e80e authored by Alberto Miranda's avatar Alberto Miranda ♨️
Browse files

Add ULT to master to listen for MPI messages

Also, enclose classes in `cargo` namespace
parent d6219fd0
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -106,7 +106,7 @@ main(int argc, char* argv[]) {
    try {
        if(world.rank() == 0) {

            master_server srv{cfg.progname, cfg.address, cfg.daemonize,
            cargo::master_server srv{cfg.progname, cfg.address, cfg.daemonize,
                                     fs::current_path()};

            if(cfg.output_file) {
+66 −17
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@
#include <fmt_formatters.hpp>
#include <boost/mpi.hpp>
#include <utility>
#include <boost/mpi/communicator.hpp>
#include "message.hpp"
#include "master.hpp"
#include "net/utilities.hpp"
@@ -37,13 +38,14 @@
#include "proto/rpc/response.hpp"

using namespace std::literals;
namespace mpi = boost::mpi;

namespace {

cargo::transfer_request_message
create_request_message(const cargo::dataset& input,
                       const cargo::dataset& output) {
cargo::transfer_request
make_request(const cargo::dataset& input, const cargo::dataset& output) {

    static std::uint64_t id = 0;
    cargo::transfer_type tx_type;

    if(input.supports_parallel_transfer()) {
@@ -54,33 +56,80 @@ create_request_message(const cargo::dataset& input,
        tx_type = cargo::sequential;
    }

    return cargo::transfer_request_message{input.path(), output.path(),
                                           tx_type};
    return cargo::transfer_request{id++, input.path(), output.path(), tx_type};
}

} // namespace

using namespace std::literals;

namespace cargo {

// Builds the master server: forwards the generic server settings to the
// `server` base, registers this object as an RPC provider on the network
// engine, and starts the MPI listener ULT on its own execution stream.
//
// The listener ULT captures `this` and begins running mpi_listener_ult()
// from the member initializer list, i.e. before the constructor body runs.
// m_mpi_listener_ess must be initialized before m_mpi_listener_ult, which
// holds because it is declared first in the class.
master_server::master_server(std::string name, std::string address,
                             bool daemonize, std::filesystem::path rundir,
                             std::optional<std::filesystem::path> pidfile)
    : server(std::move(name), std::move(address), daemonize, std::move(rundir),
             std::move(pidfile)),
      provider(m_network_engine, 0) {
      provider(m_network_engine, 0),
      m_mpi_listener_ess(thallium::xstream::create()),
      m_mpi_listener_ult(m_mpi_listener_ess->make_thread(
              [this]() { mpi_listener_ult(); })) {

// EXPAND(foo) produces the two arguments passed to provider::define():
// the RPC name as a std::string literal ("foo"s) and a pointer to the
// member function master_server::foo that handles it.
#define EXPAND(rpc_name) #rpc_name##s, &master_server::rpc_name
    provider::define(EXPAND(ping));
    provider::define(EXPAND(transfer_datasets));

#undef EXPAND

    // ESs and ULTs need to be joined before the network engine is
    // actually finalized, and ~master_server() is too late for that.
    // The push_prefinalize_callback() and push_finalize_callback() functions
    // serve this purpose. The former is called before Mercury is finalized,
    // while the latter is called in between that and Argobots finalization.
    m_network_engine.push_finalize_callback([this]() {
        // Join first, then overwrite each handle with an empty managed<>
        // so that nothing is left to release when the members are destroyed.
        m_mpi_listener_ult->join();
        m_mpi_listener_ult = thallium::managed<thallium::thread>{};
        m_mpi_listener_ess->join();
        m_mpi_listener_ess = thallium::managed<thallium::xstream>{};
    });
}

// Nothing to do explicitly: the constructor registers a finalize callback on
// the network engine that joins and resets the MPI listener ULT and its
// execution stream.
master_server::~master_server() = default;

// Body of the ULT that listens for MPI messages addressed to the master.
//
// Polls the world communicator (any source, any tag) until m_shutting_down
// becomes true. When nothing is pending the ULT sleeps on the network engine
// for 150 time units (presumably milliseconds, as defined by
// thallium::thread::sleep -- TODO confirm) before polling again.
//
// Handled tags:
//   - tag::status: receives a transfer_status and logs it.
//   - anything else: logs a warning.
void
master_server::mpi_listener_ult() {

    mpi::communicator world;

    while(!m_shutting_down) {

        auto msg = world.iprobe();

        if(!msg) {
            thallium::thread::self().sleep(m_network_engine, 150);
            continue;
        }

        switch(static_cast<cargo::tag>(msg->tag())) {
            case tag::status: {
                transfer_status st;
                // Receive from the exact rank reported by the probe. With
                // mpi::any_source the receive could match a status message
                // from a different worker than the one probed, and the
                // source logged below would then be wrong.
                world.recv(msg->source(), msg->tag(), st);
                LOGGER_INFO("[{}] Status received: {}", msg->source(), st);
                break;
            }

            default:
                // NOTE(review): the message is not received here, so it
                // stays pending and the next iprobe() will presumably report
                // it again without any sleep in between -- confirm whether
                // unexpected messages should be drained or ignored.
                LOGGER_WARN("[{}] Unexpected message tag: {}", msg->source(),
                            msg->tag());
                break;
        }
    }
}

#define RPC_NAME() ("ADM_"s + __FUNCTION__)

void
master_server::ping(const network::request& req) {

    using network::get_address;
    using network::rpc_info;
    using proto::generic_response;
@@ -89,7 +138,7 @@ master_server::ping(const network::request& req) {

    LOGGER_INFO("rpc {:>} body: {{}}", rpc);

    const auto resp = generic_response{rpc.id(), cargo::error_code{0}};
    const auto resp = generic_response{rpc.id(), error_code{0}};

    LOGGER_INFO("rpc {:<} body: {{retval: {}}}", rpc, resp.error_code());

@@ -98,9 +147,8 @@ master_server::ping(const network::request& req) {

void
master_server::transfer_datasets(const network::request& req,
                                 const std::vector<cargo::dataset>& sources,
                                 const std::vector<cargo::dataset>& targets) {

                                 const std::vector<dataset>& sources,
                                 const std::vector<dataset>& targets) {
    using network::get_address;
    using network::rpc_info;
    using proto::generic_response;
@@ -110,28 +158,29 @@ master_server::transfer_datasets(const network::request& req,
    LOGGER_INFO("rpc {:>} body: {{sources: {}, targets: {}}}", rpc, sources,
                targets);

    const auto resp = generic_response{rpc.id(), cargo::error_code{0}};
    const auto resp = generic_response{rpc.id(), error_code{0}};

    assert(sources.size() == targets.size());

    boost::mpi::communicator world;
    mpi::communicator world;
    for(auto i = 0u; i < sources.size(); ++i) {

        const auto& input_path = sources[i].path();
        const auto& output_path = targets[i].path();

        const auto m = ::create_request_message(sources[i], targets[i]);
        const auto m = ::make_request(sources[i], targets[i]);

        for(int rank = 1; rank < world.size(); ++rank) {
            world.send(rank, static_cast<int>(cargo::message_tags::transfer),
                       m);
            world.send(rank, static_cast<int>(tag::transfer), m);
        }
    }

    cargo::transfer tx{42};
    transfer tx{42};

    LOGGER_INFO("rpc {:<} body: {{retval: {}, transfer: {}}}", rpc,
                resp.error_code(), tx);

    req.respond(resp);
}

} // namespace cargo
+14 −6
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@
#include "net/server.hpp"
#include "cargo.hpp"

namespace cargo {

class master_server : public network::server,
                      public network::provider<master_server> {
public:
@@ -35,7 +37,12 @@ public:
                  std::filesystem::path rundir,
                  std::optional<std::filesystem::path> pidfile = {});

    ~master_server();

private:
    void
    mpi_listener_ult();

    void
    ping(const network::request& req);

@@ -43,13 +50,14 @@ private:
    transfer_datasets(const network::request& req,
                      const std::vector<cargo::dataset>& sources,
                      const std::vector<cargo::dataset>& targets);
};

namespace config {
struct settings;
} // namespace config
private:
    // Dedicated execution stream for the MPI listener ULT
    thallium::managed<thallium::xstream> m_mpi_listener_ess;
    // ULT for the MPI listener
    thallium::managed<thallium::thread> m_mpi_listener_ult;
};

void
master(const config::settings& cfg);
} // namespace cargo

#endif // CARGO_MASTER_HPP
+26 −20
Original line number Diff line number Diff line
@@ -33,19 +33,25 @@
namespace cargo {

enum transfer_type { parallel_read, parallel_write, sequential };
enum class message_tags { transfer, status, shutdown };
enum class tag : int { transfer, status };

class transfer_request_message {
class transfer_request {

    friend class boost::serialization::access;

public:
    transfer_request_message() = default;
    transfer_request() = default;

    transfer_request_message(const std::filesystem::path& input_path,
    transfer_request(std::uint64_t id, const std::filesystem::path& input_path,
                     const std::filesystem::path& output_path,
                     transfer_type type)
        : m_input_path(input_path), m_output_path(output_path), m_type(type) {}
        : m_id(id), m_input_path(input_path), m_output_path(output_path),
          m_type(type) {}

    std::uint64_t
    id() const {
        return m_id;
    }

    std::filesystem::path
    input_path() const {
@@ -68,24 +74,26 @@ private:
    serialize(Archive& ar, const unsigned int version) {
        (void) version;

        ar & m_id;
        ar & m_input_path;
        ar & m_output_path;
        ar & m_type;
    }

    std::uint64_t m_id;
    std::string m_input_path;
    std::string m_output_path;
    transfer_type m_type;
};

class transfer_status_message {
class transfer_status {

    friend class boost::serialization::access;

public:
    transfer_status_message() = default;
    transfer_status() = default;

    explicit transfer_status_message(std::uint64_t transfer_id)
    explicit transfer_status(std::uint64_t transfer_id)
        : m_transfer_id(transfer_id) {}

    std::uint64_t
@@ -108,12 +116,11 @@ private:
} // namespace cargo

template <>
struct fmt::formatter<cargo::transfer_request_message>
    : formatter<std::string_view> {
struct fmt::formatter<cargo::transfer_request> : formatter<std::string_view> {
    // parse is inherited from formatter<string_view>.
    template <typename FormatContext>
    auto
    format(const cargo::transfer_request_message& r, FormatContext& ctx) const {
    format(const cargo::transfer_request& r, FormatContext& ctx) const {
        const auto str = fmt::format("{{input_path: {}, output_path: {}}}",
                                     r.input_path(), r.output_path());
        return formatter<std::string_view>::format(str, ctx);
@@ -121,13 +128,12 @@ struct fmt::formatter<cargo::transfer_request_message>
};

template <>
struct fmt::formatter<cargo::transfer_status_message>
    : formatter<std::string_view> {
struct fmt::formatter<cargo::transfer_status> : formatter<std::string_view> {
    // parse is inherited from formatter<string_view>.
    template <typename FormatContext>
    auto
    format(const cargo::transfer_status_message& s, FormatContext& ctx) const {
        const auto str = fmt::format("{{transfer_id: {}}}", s.transfer_id());
    format(const cargo::transfer_status& s, FormatContext& ctx) const {
        const auto str = fmt::format("{{id: {}}}", s.transfer_id());
        return formatter<std::string_view>::format(str, ctx);
    }
};
+1 −0
Original line number Diff line number Diff line
@@ -379,6 +379,7 @@ server::teardown_and_exit() {

// Requests server shutdown.
//
// The shutdown flag is raised before the network engine is finalized.
// master_server::mpi_listener_ult() loops on `!m_shutting_down`, and the
// finalize callback registered by master_server joins that ULT, so the flag
// has to be set first for the join to be able to complete.
//
// NOTE(review): the declaration of m_shutting_down is not visible here.
// Since it is read from a ULT running on a separate execution stream,
// confirm that it is an atomic (or otherwise synchronized) flag.
void
server::shutdown() {
    m_shutting_down = true;
    m_network_engine.finalize();
}

Loading