Verified Commit c4895fd6 authored by Alberto Miranda's avatar Alberto Miranda ♨️
Browse files

Refactor MPI messages

parent 8580b39e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -38,13 +38,13 @@ target_sources(
          worker.hpp
          env.hpp
          mpioxx.hpp
          message.hpp
          request.cpp
          request.hpp
          request_manager.cpp
          request_manager.hpp
          shared_mutex.hpp
          proto/rpc/response.hpp
          proto/mpi/message.hpp
)

target_include_directories(
+34 −24
Original line number Diff line number Diff line
@@ -31,11 +31,11 @@
#include <boost/mpi.hpp>
#include <utility>
#include <boost/mpi/communicator.hpp>
#include "message.hpp"
#include "master.hpp"
#include "net/utilities.hpp"
#include "net/request.hpp"
#include "proto/rpc/response.hpp"
#include "proto/mpi/message.hpp"
#include "request.hpp"

using namespace std::literals;
@@ -43,21 +43,25 @@ namespace mpi = boost::mpi;

namespace {

// NOTE(review): this span is a rendered diff hunk whose +/- markers were
// stripped, so two versions of make_message() are interleaved below and the
// text is not compilable as shown. Which lines are "old" and which are "new"
// is inferred from the rename to transfer_message -- confirm against the
// actual commit.
//
// Version returning cargo::transfer_request (presumably the removed one):
// takes (input, output), selects a cargo::transfer_type by calling
// supports_parallel_transfer() on the input and then on the output, and
// numbers each request from a function-local static counter.
//
// Version returning std::tuple<int, cargo::transfer_message> (presumably the
// added one): takes (tid, seqno, input, output) and returns a cargo::tag cast
// to int together with the message. The tag is pread when the input supports
// parallel transfer, otherwise pwrite when the output does, otherwise
// sequential; the message contents are the same in all three cases. The
// caller passes the int as the tag argument of world.send().
cargo::transfer_request
make_message(const cargo::dataset& input, const cargo::dataset& output) {

    static std::uint64_t id = 0;
    cargo::transfer_type tx_type;
std::tuple<int, cargo::transfer_message>
make_message(std::uint64_t tid, std::uint32_t seqno,
             const cargo::dataset& input, const cargo::dataset& output) {

    if(input.supports_parallel_transfer()) {
        tx_type = cargo::parallel_read;
    } else if(output.supports_parallel_transfer()) {
        tx_type = cargo::parallel_write;
    } else {
        tx_type = cargo::sequential;
        return std::make_tuple(static_cast<int>(cargo::tag::pread),
                               cargo::transfer_message{tid, seqno, input.path(),
                                                       output.path()});
    }

    if(output.supports_parallel_transfer()) {
        return std::make_tuple(static_cast<int>(cargo::tag::pwrite),
                               cargo::transfer_message{tid, seqno, input.path(),
                                                       output.path()});
    }

    return cargo::transfer_request{id++, input.path(), output.path(), tx_type};
    return std::make_tuple(
            static_cast<int>(cargo::tag::sequential),
            cargo::transfer_message{tid, seqno, input.path(), output.path()});
}

} // namespace
@@ -113,15 +117,21 @@ master_server::mpi_listener_ult() {

        switch(static_cast<cargo::tag>(msg->tag())) {
            case tag::status: {
                transfer_status st;
                world.recv(mpi::any_source, msg->tag(), st);
                LOGGER_INFO("[{}] Status received: {}", msg->source(), st);
                status_message m;
                world.recv(mpi::any_source, msg->tag(), m);
                LOGGER_INFO("msg => from: {} body: {{status: {}}}",
                            msg->source(), m);

                m_request_manager.update(m.tid(), m.seqno(), msg->source() - 1,
                                         m.error_code()
                                                 ? part_status::failed
                                                 : part_status::completed);
                break;
            }

            default:
                LOGGER_WARN("[{}] Unexpected message tag: {}", msg->source(),
                            msg->tag());
                LOGGER_WARN("msg => from: {} body: {{Unexpected tag: {}}}",
                            msg->source(), msg->tag());
                break;
        }
    }
@@ -161,7 +171,7 @@ master_server::transfer_datasets(const network::request& req,
    LOGGER_INFO("rpc {:>} body: {{sources: {}, targets: {}}}", rpc, sources,
                targets);

    m_request_manager.create(world.size() - 1, sources, targets)
    m_request_manager.create(sources.size(), world.size() - 1)
            .or_else([&](auto&& ec) {
                LOGGER_ERROR("Failed to create request: {}", ec);
                LOGGER_INFO("rpc {:<} body: {{retval: {}}}", rpc, ec);
@@ -175,16 +185,16 @@ master_server::transfer_datasets(const network::request& req,
                    const auto& d = targets[i];

                    for(std::size_t rank = 1; rank <= r.nworkers(); ++rank) {
                        world.send(static_cast<int>(rank),
                                   static_cast<int>(tag::transfer),
                                   make_message(s, d));
                        const auto [t, m] = make_message(r.tid(), i, s, d);
                        LOGGER_INFO("msg <= to: {} body: {}", rank, m);
                        world.send(static_cast<int>(rank), t, m);
                    }
                }

                LOGGER_INFO("rpc {:<} body: {{retval: {}, transfer_id: {}}}",
                            rpc, error_code::success, r.id());
                LOGGER_INFO("rpc {:<} body: {{retval: {}, tid: {}}}", rpc,
                            error_code::success, r.tid());
                req.respond(response_with_id{rpc.id(), error_code::success,
                                             r.id()});
                                             r.tid()});
            });
}

+58 −42
Original line number Diff line number Diff line
@@ -22,83 +22,93 @@
 * SPDX-License-Identifier: GPL-3.0-or-later
 *****************************************************************************/

#ifndef CARGO_MESSAGE_HPP
#define CARGO_MESSAGE_HPP
#ifndef CARGO_PROTO_MPI_MESSAGE_HPP
#define CARGO_PROTO_MPI_MESSAGE_HPP

#include <fmt/format.h>
#include <filesystem>
#include <boost/archive/binary_oarchive.hpp>
#include <utility>
#include "cargo/error.hpp"

namespace cargo {

// NOTE(review): rendered diff with +/- markers stripped; the two `tag`
// definitions below are alternative versions, not a redefinition. The
// four-value one (pread, pwrite, sequential, status) matches the tags used by
// the tuple-returning make_message() and is presumably the new one, with the
// kind of transfer now carried by the tag instead of by transfer_type.
// NOTE(review): `tag::transfer` is still used as a `case` label in the worker
// hunk further down, but the four-value enum has no such enumerator --
// confirm the worker is updated, otherwise it will not compile.
enum transfer_type { parallel_read, parallel_write, sequential };
enum class tag : int { transfer, status };
enum class tag : int { pread, pwrite, sequential, status };

// NOTE(review): rendered diff with +/- markers stripped; the `transfer_request`
// and `transfer_message` versions of this class are interleaved below, so the
// text is not compilable as shown. transfer_message is presumably the new name.
//
// Message describing one transfer, holding an input path and an output path.
// The transfer_request version also carries an id and a transfer_type; the
// transfer_message version carries a transfer id (tid) and a sequence number
// (seqno) instead. The private serialize() member is reached through the
// boost::serialization::access friend declaration.
class transfer_request {
class transfer_message {

    friend class boost::serialization::access;

public:
    transfer_request() = default;
    transfer_message() = default;

    // The transfer_message constructor takes both paths by value as
    // std::string and moves them into the members.
    transfer_request(std::uint64_t id, const std::filesystem::path& input_path,
                     const std::filesystem::path& output_path,
                     transfer_type type)
        : m_id(id), m_input_path(input_path), m_output_path(output_path),
          m_type(type) {}
    transfer_message(std::uint64_t tid, std::uint32_t seqno,
                     std::string input_path, std::string output_path)
        : m_tid(tid), m_seqno(seqno), m_input_path(std::move(input_path)),
          m_output_path(std::move(output_path)) {}

    std::uint64_t
    id() const {
        return m_id;
    [[nodiscard]] std::uint64_t
    tid() const {
        return m_tid;
    }

    std::filesystem::path
    [[nodiscard]] std::uint32_t
    seqno() const {
        return m_seqno;
    }

    // Path accessors: one version returns std::filesystem::path by value, the
    // other returns a const reference to the stored std::string.
    [[nodiscard]] const std::string&
    input_path() const {
        return m_input_path;
    }

    std::filesystem::path
    [[nodiscard]] const std::string&
    output_path() const {
        return m_output_path;
    }

    transfer_type
    type() const {
        return m_type;
    }

private:
    // Archives every data member in declaration order; `version` is unused.
    template <class Archive>
    void
    serialize(Archive& ar, const unsigned int version) {
        (void) version;

        ar & m_id;
        ar & m_tid;
        ar & m_seqno;
        ar & m_input_path;
        ar & m_output_path;
        ar & m_type;
    }

    // m_id and m_type have no default initializer; m_tid and m_seqno are
    // value-initialized.
    std::uint64_t m_id;
    std::uint64_t m_tid{};
    std::uint32_t m_seqno{};
    std::string m_input_path;
    std::string m_output_path;
    transfer_type m_type;
};

// NOTE(review): rendered diff with +/- markers stripped; the `transfer_status`
// and `status_message` versions of this class are interleaved below, and the
// `@@` hunk header in the middle means some lines between `private:` and
// `serialize(` are not shown. status_message is presumably the new name.
//
// Status report for a transfer. The transfer_status version holds only a
// transfer id; the status_message version holds a transfer id (tid), a
// sequence number (seqno) and a cargo::error_code. serialize() is reached
// through the boost::serialization::access friend declaration.
class transfer_status {
class status_message {

    friend class boost::serialization::access;

public:
    transfer_status() = default;
    status_message() = default;

    explicit transfer_status(std::uint64_t transfer_id)
        : m_transfer_id(transfer_id) {}
    status_message(std::uint64_t tid, std::uint32_t seqno,
                   cargo::error_code error_code)
        : m_tid(tid), m_seqno(seqno), m_error_code(error_code) {}

    [[nodiscard]] std::uint64_t
    tid() const {
        return m_tid;
    }

    [[nodiscard]] std::uint32_t
    seqno() const {
        return m_seqno;
    }

    std::uint64_t
    transfer_id() const {
        return m_transfer_id;
    [[nodiscard]] cargo::error_code
    error_code() const {
        return m_error_code;
    }

private:
@@ -107,35 +117,41 @@ private:
    // Archives every data member in declaration order; `version` is unused.
    serialize(Archive& ar, const unsigned int version) {
        (void) version;

        ar& m_transfer_id;
        ar & m_tid;
        ar & m_seqno;
        ar & m_error_code;
    }

    // All members are value-initialized by their default member initializers.
    std::uint64_t m_transfer_id{};
    std::uint64_t m_tid{};
    std::uint32_t m_seqno{};
    cargo::error_code m_error_code{};
};

} // namespace cargo

// NOTE(review): rendered diff with +/- markers stripped; the specializations
// for cargo::transfer_request and cargo::transfer_message are interleaved
// below, so the text is not compilable as shown.
//
// fmt formatter for the transfer message: builds a brace-delimited string
// with fmt::format and hands it to the inherited string_view formatter. One
// version prints only input_path and output_path; the other also prints tid
// and seqno.
template <>
struct fmt::formatter<cargo::transfer_request> : formatter<std::string_view> {
struct fmt::formatter<cargo::transfer_message> : formatter<std::string_view> {
    // parse is inherited from formatter<string_view>.
    template <typename FormatContext>
    auto
    format(const cargo::transfer_request& r, FormatContext& ctx) const {
        const auto str = fmt::format("{{input_path: {}, output_path: {}}}",
                                     r.input_path(), r.output_path());
    format(const cargo::transfer_message& r, FormatContext& ctx) const {
        const auto str = fmt::format(
                "{{tid: {}, seqno: {}, input_path: {}, output_path: {}}}",
                r.tid(), r.seqno(), r.input_path(), r.output_path());
        return formatter<std::string_view>::format(str, ctx);
    }
};

// NOTE(review): rendered diff with +/- markers stripped; the specializations
// for cargo::transfer_status and cargo::status_message are interleaved below,
// so the text is not compilable as shown.
//
// fmt formatter for the status message: builds a brace-delimited string with
// fmt::format and hands it to the inherited string_view formatter. One
// version prints only the transfer id; the other prints tid, seqno and
// error_code.
template <>
struct fmt::formatter<cargo::transfer_status> : formatter<std::string_view> {
struct fmt::formatter<cargo::status_message> : formatter<std::string_view> {
    // parse is inherited from formatter<string_view>.
    template <typename FormatContext>
    auto
    format(const cargo::transfer_status& s, FormatContext& ctx) const {
        const auto str = fmt::format("{{id: {}}}", s.transfer_id());
    format(const cargo::status_message& s, FormatContext& ctx) const {
        const auto str = fmt::format("{{tid: {}, seqno: {}, error_code: {}}}",
                                     s.tid(), s.seqno(), s.error_code());
        return formatter<std::string_view>::format(str, ctx);
    }
};

#endif // CARGO_MESSAGE_HPP
#endif // CARGO_PROTO_MPI_MESSAGE_HPP
+4 −3
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@
#include <posix_file/file.hpp>
#include <posix_file/views.hpp>
#include "worker.hpp"
#include "message.hpp"
#include "proto/mpi/message.hpp"
#include "mpioxx.hpp"

namespace mpi = boost::mpi;
@@ -337,7 +337,7 @@ worker::run() {

        switch(static_cast<tag>(msg->tag())) {
            case tag::transfer: {
                transfer_request m;
                transfer_message m;
                world.recv(0, msg->tag(), m);
                LOGGER_DEBUG("Transfer request received!: {}", m);

@@ -358,7 +358,8 @@ worker::run() {
                        world.rank(), workers.rank());

                world.send(msg->source(), static_cast<int>(tag::status),
                           transfer_status{m.id()});
                           status_message{m.tid(), m.seqno(),
                                          error_code::success});

                break;
            }