diff --git a/.gitignore b/.gitignore index 92bfd738..36d924a1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ +*.idx + # build directory build/ +bin/ # vim temporary files *.swp diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..fcc67474 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,107 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What this is + +`mgconsole` is a C++20 command-line client for the [Memgraph](https://memgraph.com) graph +database. It talks the Bolt protocol via the bundled `mgclient` library, reads Cypher from +either an interactive prompt (replxx) or stdin, and prints results as tabular / CSV / cypherl. + +## Build + +Dependencies are fetched and built from source via CMake `ExternalProject` (`gflags` pinned to +`70c01a6`, `mgclient` pinned to `v1.6.0`); `OpenSSL` and a C++20 compiler must be present. The +first configure/build is slow because of these external builds. + +```bash +cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release . # macOS: add -DOPENSSL_ROOT_DIR="$(brew --prefix openssl)" +cmake --build build +cmake --install build # installs to /usr (Linux) or /usr/local (macOS) by default +``` + +The binary lands at `build/src/mgconsole`. `compile_commands.json` is emitted into `build/`. +`-Wall -Wextra -pedantic -Werror` is enabled — warnings break the build. + +**Static / release build** (matches what ships): `./build-generic-linux.sh` builds inside the +`memgraph/mgbuild` Docker image with the Memgraph toolchain and `-DMGCONSOLE_STATIC_SSL=ON`, +producing `build/generic/mgconsole`. + +## Test + +Tests are wired into CTest and require a Memgraph instance (binary or Docker image) to run +against — they spin one up, exercise the client, then tear it down. 
+ +```bash +# Configure tests against a Docker Memgraph (no local binary needed): +cmake -B build -G Ninja -DMEMGRAPH_USE_DOCKER=ON -DMEMGRAPH_DOCKER_IMAGE=memgraph/memgraph:latest +cmake --build build +ctest --verbose --test-dir build # runs all tests + +ctest --test-dir build -R parameters-unit-test # single unit test (no DB needed) +ctest --test-dir build -R mgconsole-test # end-to-end I/O tests (plaintext) +ctest --test-dir build -R mgconsole-secure-test # same, over SSL +``` + +To point at a local Memgraph binary instead of Docker, configure with +`-DMEMGRAPH_PATH=/path/to/memgraph` and leave `MEMGRAPH_USE_DOCKER=OFF` (default). + +Two test kinds: +- **Unit** (`tests/unit/`): `parameters_test.cpp` links the standalone `params` library and runs + without a database. +- **End-to-end** (`tests/input_output/`): driven by `run-tests.sh`. For every file in `input/`, + it runs the client once per output format and diffs stdout against the matching golden file in + `output_tabular/`, `output_csv/`, etc. **When you change output formatting or add an input + case, regenerate/add the corresponding golden file in every `output_*` directory** — the test + matrix is the cross-product of inputs × formats. + +## Architecture + +`main.cpp` parses gflags, sets up signal handlers, and dispatches to exactly one of four "modes" +based on whether stdin is a TTY and the `--import-mode` flag. Each mode lives in its own +translation unit under the `mode::` namespace and implements a `Run(...)` entry point: + +- **`mode::interactive`** (`interactive.cpp`) — chosen when stdin is a TTY. The replxx REPL loop: + reads a query, handles `:param`/`:params`/`:help`/`:quit` commands, executes via `mgclient`, + prints results, manages history, and reconnects (3 retries) on fatal connection errors. +- **`mode::serial_import`** (`serial_import.cpp`) — default non-interactive (piped) path. Reads + queries one at a time and executes them in order. This is the `DUMP DATABASE | mgconsole` path. 
+- **`mode::batch_import`** (`batch_import.cpp`) — `--import-mode=batched-parallel`. EXPERIMENTAL. + Classifies each query (via `QueryInfo`) as pre/vertex/edge/post, groups vertex and edge queries + into batches, and executes batches concurrently across a thread pool of `--workers-number` + Bolt sessions, with exponential backoff + retry on failure. Vertices are flushed before edges + because edges depend on existing vertices. Query classification is heuristic — that's why this + mode is experimental. +- **`mode::parsing`** (`parsing.cpp`) — `--import-mode=parser`. Parses queries and prints + `QueryInfo` stats without touching the database. + +All four modes funnel through shared primitives in `src/utils/`, organized by namespace within +`utils.hpp`/`utils.cpp` (a large ~1300-line file): + +- **`query::`** — `GetQuery()` reads and accumulates a complete (`;`-terminated, possibly + multi-line) query from the input source, optionally producing `QueryInfo` (the has_create / + has_match / has_merge / ... flags that drive batch classification). `ExecuteQuery()` / + `ExecuteBatch()` run against an `mg_session`. `QueryResult` carries records, header, timing, + notifications, and execution stats. +- **`console::`** — TTY detection, line reading, `Echo*` output helpers (failure/info/stats). +- **`format::`** — `CsvOptions` and `OutputOptions`; `Output()` renders a result set as tabular, + CSV, or cypherl. +- **`utils::bolt`** (`bolt.hpp`/`bolt.cpp`) — `Config` struct, `MakeBoltSession()` (direct), `MakeRoutedBoltSession()` (resolves the main via a coordinator's routing table) and the `RoutedSession` wrapper; + the single place sessions are created (direct or routed connection). + +`src/parameters.{hpp,cpp}` is deliberately a **separate static library (`params`)** depending +only on `mgclient`, so the `:param` parsing/storage logic can be unit-tested in isolation. Don't +add heavier dependencies to it. + +Concurrency support for batch mode is custom and lives in `utils/`: `thread_pool`, `future` +(promise/future with notification hooks), `notifier`, and `synchronized`. 
`mg_memory.hpp` wraps +raw `mgclient` C pointers in RAII unique-ptr types (`MgSessionPtr`, `MgValuePtr`, etc.) — use +these rather than managing `mg_*` lifetimes by hand. + +## Conventions + +- Every source file carries the GPLv3 license header — copy it onto new files. +- Formatting is enforced by `.clang-format` (Google base, 120 col). Run `clang-format` before + committing. +- `MG_ASSERT` / `MG_FAIL` (`utils/assert.hpp`) are the assertion/abort macros. +- `date.hpp` is a large vendored third-party header (Howard Hinnant's date lib) — don't edit it. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5fc08611..a4f92bca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,7 +90,7 @@ get_filename_component(MGCLIENT_LIB_PATH "${MGCLIENT_PREFIX}/${MG_INSTALL_LIB_DI ExternalProject_Add(mgclient-proj PREFIX ${MGCLIENT_PREFIX} GIT_REPOSITORY https://github.com/memgraph/mgclient.git - GIT_TAG v1.5.0 + GIT_TAG v1.6.0 CMAKE_ARGS "-DCMAKE_INSTALL_PREFIX=" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" diff --git a/src/batch_import.cpp b/src/batch_import.cpp index c0a0fb3d..ccf1f836 100644 --- a/src/batch_import.cpp +++ b/src/batch_import.cpp @@ -15,6 +15,7 @@ #include "batch_import.hpp" +#include #include #include #include @@ -156,10 +157,19 @@ struct BatchExecutionContext { : batch_size(batch_size), max_batches(max_batches), max_concurrent_executions(max_concurrent_executions), - thread_pool(max_concurrent_executions) { + thread_pool(max_concurrent_executions), + config(bolt_config) { sessions.reserve(max_concurrent_executions); for (uint64_t thread_i = 0; thread_i < max_concurrent_executions; ++thread_i) { - sessions[thread_i] = MakeBoltSession(bolt_config); + // Each worker connects to the resolved main. In routed mode every worker + // re-routes independently at creation, so the coordinator may receive up + // to `workers_number` (e.g. 32) ROUTE calls at startup. 
The TTL expiry of + // the last route determines when sessions are refreshed (see Run). + if (config.routed_connection) { + sessions.push_back(utils::bolt::MakeRoutedBoltSession(config, &expiry)); + } else { + sessions.push_back(MakeBoltSession(config)); + } if (!sessions[thread_i].get()) { MG_FAIL("a session uninitialized"); } @@ -175,6 +185,10 @@ struct BatchExecutionContext { utils::ThreadPool thread_pool{max_concurrent_executions}; utils::Notifier notifier; std::vector sessions; + /// Connection config; kept so workers can be reconnected / re-routed. + utils::bolt::Config config; + /// When the current routing table's TTL expires (routed mode only). + std::chrono::steady_clock::time_point expiry{}; }; Batches FetchBatches(BatchExecutionContext &execution_context) { @@ -201,7 +215,7 @@ Batches FetchBatches(BatchExecutionContext &execution_context) { void ExecuteSerial(const std::vector &queries, BatchExecutionContext &context) { for (const auto &query : queries) { try { - query::ExecuteQuery(context.sessions[0].get(), query.query); + query::ExecuteQuery(context.sessions[0].get(), query.query, nullptr, context.config.db); } catch (const utils::ClientQueryException &e) { console::EchoFailure("Client received query exception", e.what()); MG_FAIL("Unable to ExecuteSerial"); @@ -248,7 +262,7 @@ uint64_t ExecuteBatchesParallel(std::vector &batches, BatchExecuti if (batch.backoff > 1) { std::this_thread::sleep_for(std::chrono::milliseconds(batch.backoff)); } - auto ret = query::ExecuteBatch(execution_context.sessions[thread_i].get(), batch); + auto ret = query::ExecuteBatch(execution_context.sessions[thread_i].get(), batch, bolt_config.db); if (ret.is_executed) { batch.is_executed = true; executed_batches++; @@ -265,7 +279,11 @@ uint64_t ExecuteBatchesParallel(std::vector &batches, BatchExecuti promise->Fill(false); } if (mg_session_status(execution_context.sessions[thread_i].get()) == MG_SESSION_BAD) { - execution_context.sessions[thread_i] = 
MakeBoltSession(bolt_config); + if (bolt_config.routed_connection) { + execution_context.sessions[thread_i] = utils::bolt::MakeRoutedBoltSession(bolt_config, nullptr); + } else { + execution_context.sessions[thread_i] = MakeBoltSession(bolt_config); + } } }); f_execs.insert_or_assign(thread_i, std::move(future)); @@ -287,6 +305,19 @@ int Run(const utils::bolt::Config &bolt_config, int batch_size, int workers_numb // (workers_number). BatchExecutionContext execution_context(batch_size, workers_number, workers_number, bolt_config); while (true) { + // Round boundary checkpoint (single-threaded): in routed mode, if the + // routing table's TTL has expired, re-route every worker session (one ROUTE + // call per worker) so each reconnects to the (possibly new) main. This keeps proactive TTL + // refresh out of the parallel rounds where it would be race-prone. + if (bolt_config.routed_connection && std::chrono::steady_clock::now() >= execution_context.expiry) { + for (uint64_t thread_i = 0; thread_i < execution_context.max_concurrent_executions; ++thread_i) { + execution_context.sessions[thread_i] = + utils::bolt::MakeRoutedBoltSession(bolt_config, &execution_context.expiry); + if (!execution_context.sessions[thread_i].get()) { + MG_FAIL("failed to re-route a worker session after TTL expiry"); + } + } + } auto batches = FetchBatches(execution_context); if (batches.Empty()) { break; diff --git a/src/interactive.cpp b/src/interactive.cpp index c533d09a..c5d411f2 100644 --- a/src/interactive.cpp +++ b/src/interactive.cpp @@ -34,8 +34,8 @@ namespace params = query::params; // Evaluates a Cypher expression server-side and returns a copy of the resulting // value. Existing parameters are made available to the expression. 
mg_memory::MgValuePtr EvaluateParamExpression(mg_session *session, const std::string &expression, - const params::ParamStore &store) { - auto result = query::ExecuteQuery(session, "RETURN " + expression, store.AsMap().get()); + const params::ParamStore &store, const std::string &db) { + auto result = query::ExecuteQuery(session, "RETURN " + expression, store.AsMap().get(), db); if (result.records.empty() || mg_list_size(result.records.front().get()) == 0) { throw utils::ClientQueryException("expression did not produce a value"); } @@ -59,7 +59,8 @@ void ListParams(const params::ParamStore &store) { // Handles a `:param`/`:params` command line. Query-level failures (e.g. a bad // expression) are reported without aborting the shell; fatal connection // failures propagate to the reconnect logic in Run. -void HandleParamCommand(mg_session *session, params::ParamStore &store, const std::string &line) { +void HandleParamCommand(mg_session *session, params::ParamStore &store, const std::string &line, + const std::string &db) { const auto parsed = params::ParseParamCommand(line); if (!parsed.command) { console::EchoFailure("Invalid parameter command", parsed.error); @@ -68,7 +69,7 @@ void HandleParamCommand(mg_session *session, params::ParamStore &store, const st switch (parsed.command->kind) { case params::ParamCommand::Kind::kSet: try { - auto value = EvaluateParamExpression(session, parsed.command->expression, store); + auto value = EvaluateParamExpression(session, parsed.command->expression, store, db); store.Set(parsed.command->name, value.get()); console::EchoInfo("Set parameter '" + parsed.command->name + "'"); } catch (const utils::ClientQueryException &e) { @@ -153,8 +154,8 @@ int Run(utils::bolt::Config &bolt_config, const std::string &history, bool no_hi }; int num_retries = 3; - auto session = MakeBoltSession(bolt_config); - if (session.get() == nullptr) { + utils::bolt::RoutedSession session(bolt_config); + if (!session.Connected()) { cleanup_resources(); 
return 1; } @@ -179,7 +180,7 @@ int Run(utils::bolt::Config &bolt_config, const std::string &history, bool no_hi try { if (query->is_param_command) { - HandleParamCommand(session.get(), param_store, query->query); + HandleParamCommand(session.Get(), param_store, query->query, bolt_config.db); auto history_ret = save_history(); if (history_ret != 0) { cleanup_resources(); @@ -187,7 +188,12 @@ int Run(utils::bolt::Config &bolt_config, const std::string &history, bool no_hi } continue; } - auto ret = query::ExecuteQuery(session.get(), query->query, param_store.AsMap().get()); + // Resolve the session (may proactively re-route) before observing this query: the re-route decision must + // use the transaction state as it was *before* this query, so a COMMIT/ROLLBACK still runs on the session + // that holds the open transaction. + auto *session_ptr = session.Get(); + session.ObserveQuery(query->query); + auto ret = query::ExecuteQuery(session_ptr, query->query, param_store.AsMap().get(), bolt_config.db); if (ret.records.size() > 0) { Output(ret.header, ret.records, output_opts, csv_opts); } @@ -220,14 +226,12 @@ int Run(utils::bolt::Config &bolt_config, const std::string &history, bool no_hi console::EchoFailure("Client received connection exception", e.what()); console::EchoInfo("Trying to reconnect..."); bool is_connected = false; - session.reset(nullptr); while (num_retries > 0) { --num_retries; - session = utils::bolt::MakeBoltSession(bolt_config); - if (session.get() == nullptr) { - console::EchoFailure("Connection failure", mg_session_error(session.get())); - session.reset(nullptr); - } else { + // In routed mode this re-fetches the routing table (failover); in + // direct mode it reconnects to the same instance. 
+ session.Reconnect(); + if (session.Connected()) { is_connected = true; break; } diff --git a/src/main.cpp b/src/main.cpp index 2aeb28ab..eb2103a4 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -60,6 +60,12 @@ DEFINE_int32(port, 7687, "Server port."); DEFINE_string(username, "", "Database username."); DEFINE_string(password, "", "Database password."); DEFINE_bool(use_ssl, false, "Use SSL when connecting to the server."); +DEFINE_string(connection_type, "direct", + "(direct|routing) If routing, uses client-side routing protocol to connect to the main data instance."); +// Selects the target database (sent in the ROUTE extra and in every RUN/BEGIN extra). When empty, the +// server's default database is used. Multi-database selection requires Memgraph enterprise; community ignores it. +DEFINE_string(db, "", "Database to use. When set, queries run against this database (enterprise multi-tenancy); " + "a non-existent database fails with 'Unknown database name'. Empty uses the default database."); // output DEFINE_bool(fit_to_screen, false, "Fit output width to screen width."); @@ -180,13 +186,19 @@ int main(int argc, char **argv) { #endif /* _WIN32 */ - utils::bolt::Config bolt_config{ - .host = FLAGS_host, - .port = FLAGS_port, - .username = FLAGS_username, - .password = FLAGS_password, - .use_ssl = FLAGS_use_ssl, - }; + auto const connection_type = utils::ToLowerCase(FLAGS_connection_type); + if (connection_type != "direct" && connection_type != "routing") { + console::EchoFailure("Unsupported connection type!", "Connection type can be 'direct' or 'routing'."); + return 1; + } + + utils::bolt::Config bolt_config{.db = FLAGS_db, + .host = FLAGS_host, + .username = FLAGS_username, + .password = FLAGS_password, + .port = FLAGS_port, + .use_ssl = FLAGS_use_ssl, + .routed_connection = connection_type == "routing"}; if (console::is_a_tty(STDIN_FILENO)) { // INTERACTIVE auto const history_file = std::invoke([&]() -> std::string { diff --git a/src/serial_import.cpp 
b/src/serial_import.cpp index 6c6c9cd3..d8e45372 100644 --- a/src/serial_import.cpp +++ b/src/serial_import.cpp @@ -21,8 +21,8 @@ using namespace std::string_literals; int Run(const utils::bolt::Config &bolt_config, const format::CsvOptions &csv_opts, const format::OutputOptions &output_opts) { - auto session = MakeBoltSession(bolt_config); - if (session.get() == nullptr) { + utils::bolt::RoutedSession session(bolt_config); + if (!session.Connected()) { return 1; } @@ -36,7 +36,11 @@ int Run(const utils::bolt::Config &bolt_config, const format::CsvOptions &csv_op } try { - auto ret = query::ExecuteQuery(session.get(), query->query); + // Resolve the session (may proactively re-route) before observing this query so transaction-control + // statements in a dump stream are not interrupted by a re-route mid-transaction. + auto *session_ptr = session.Get(); + session.ObserveQuery(query->query); + auto ret = query::ExecuteQuery(session_ptr, query->query, nullptr, bolt_config.db); if (ret.records.size() > 0) { Output(ret.header, ret.records, output_opts, csv_opts); } diff --git a/src/utils/bolt.cpp b/src/utils/bolt.cpp index a9bede90..8401b626 100644 --- a/src/utils/bolt.cpp +++ b/src/utils/bolt.cpp @@ -15,13 +15,20 @@ #include "bolt.hpp" +#include +#include + #include "gflags/gflags.h" namespace utils::bolt { using namespace std::string_literals; -mg_memory::MgSessionPtr MakeBoltSession(const Config &config) { +namespace { + +// Opens a direct connection to config.host:config.port (the same logic that +// MakeBoltSession has always used). Returns a null MgSessionPtr on failure. 
+mg_memory::MgSessionPtr MakeDirectSession(const Config &config) { std::string bolt_client_version = "mg/"s + gflags::VersionString(); mg_memory::MgSessionParamsPtr params = mg_memory::MakeCustomUnique(mg_session_params_make()); if (!params) { @@ -49,4 +56,205 @@ mg_memory::MgSessionPtr MakeBoltSession(const Config &config) { return session; } +// Splits a "host:port" address on its last ':' so IPv6 / hostnames with +// embedded colons still parse. Returns std::nullopt on a malformed address. +std::optional> SplitHostPort(const std::string &address) { + const auto colon = address.rfind(':'); + if (colon == std::string::npos || colon == 0 || colon + 1 >= address.size()) { + return std::nullopt; + } + const std::string host = address.substr(0, colon); + const std::string port_str = address.substr(colon + 1); + try { + size_t consumed = 0; + const int port = std::stoi(port_str, &consumed); + if (consumed != port_str.size() || port <= 0 || port > 65535) { + return std::nullopt; + } + return std::make_pair(host, port); + } catch (const std::exception &) { + return std::nullopt; + } +} + +// Reads an mg_string value into a std::string (mg_string is not null-terminated). +std::string ToString(const mg_string *str) { + if (!str) { + return {}; + } + return std::string(mg_string_data(str), mg_string_size(str)); +} + +// Eagerly checks that `db` can be selected on `session` by running a throwaway query that carries it, so an +// unknown database fails at connect time instead of on the user's first query. Returns false (after reporting) +// ONLY when the database itself is unknown. Any other rejection is not a database problem and must not abort +// the connection: most importantly, a direct connection to a coordinator rejects all Cypher ("Coordinator can +// run only coordinator queries!"), yet must remain usable for coordinator commands. 
+bool ValidateDatabase(mg_session *session, const std::string &db) { + if (db.empty()) { + return true; + } + try { + query::ExecuteQuery(session, "RETURN 1", nullptr, db); + return true; + } catch (const utils::ClientQueryException &e) { + if (utils::ToLowerCase(e.what()).find("unknown database") != std::string::npos) { + console::EchoFailure("Database selection failure", e.what()); + return false; + } + // Not a database error (e.g. a coordinator that can't run Cypher); let the connection proceed. + return true; + } catch (const utils::ClientFatalException &e) { + console::EchoFailure("Connection failure", e.what()); + return false; + } +} + +} // namespace + +mg_memory::MgSessionPtr MakeBoltSession(const Config &config) { + auto session = MakeDirectSession(config); + if (session.get() != nullptr && !ValidateDatabase(session.get(), config.db)) { + return mg_memory::MakeCustomUnique(nullptr); + } + return session; +} + +// Currently, it sends both write and read queries to the current main +mg_memory::MgSessionPtr MakeRoutedBoltSession(const Config &config, std::chrono::steady_clock::time_point *expiry_out) { + // 1. Connect to the coordinator. + Config coord_config = config; + coord_config.routed_connection = false; + auto coord = MakeDirectSession(coord_config); + if (coord.get() == nullptr) { + return mg_memory::MakeCustomUnique(nullptr); + } + + // 2. Build the routing context and extra map, then send the ROUTE message. 
+ const std::string coord_address = config.host + ":" + std::to_string(config.port); + mg_memory::MgMapPtr routing = mg_memory::MakeCustomUnique(mg_map_make_empty(1)); + mg_memory::MgMapPtr extra = mg_memory::MakeCustomUnique(mg_map_make_empty(1)); + if (!routing || !extra) { + console::EchoFailure("Routing failure", "out of memory, failed to allocate routing maps"); + return mg_memory::MakeCustomUnique(nullptr); + } + mg_map_insert(routing.get(), "address", mg_value_make_string(coord_address.c_str())); + if (!config.db.empty()) { + mg_map_insert(extra.get(), "db", mg_value_make_string(config.db.c_str())); + } + + mg_map *rt_raw = nullptr; + const int status = mg_session_route(coord.get(), routing.get(), nullptr, extra.get(), &rt_raw); + // Take ownership immediately so the routing table is always freed. + mg_memory::MgMapPtr rt = mg_memory::MakeCustomUnique(rt_raw); + if (status != 0) { + console::EchoFailure("Routing failure", mg_session_error(coord.get())); + return mg_memory::MakeCustomUnique(nullptr); + } + + // 3. Parse the routing table: TTL (seconds) and the WRITE (main) instance. 
+ const mg_value *ttl_val = mg_map_at(rt.get(), "ttl"); + if (ttl_val != nullptr && mg_value_get_type(ttl_val) == MG_VALUE_TYPE_INTEGER && expiry_out != nullptr) { + *expiry_out = std::chrono::steady_clock::now() + std::chrono::seconds(mg_value_integer(ttl_val)); + } + + const mg_value *servers_val = mg_map_at(rt.get(), "servers"); + if (servers_val == nullptr) { + console::EchoFailure("Routing failure", "routing table has no 'servers' entry"); + return mg_memory::MakeCustomUnique(nullptr); + } + const mg_list *servers = mg_value_list(servers_val); + std::optional write_address; + for (uint32_t i = 0; i < mg_list_size(servers); ++i) { + const mg_value *server_val = mg_list_at(servers, i); + const mg_map *server = mg_value_map(server_val); + if (server == nullptr) { + continue; + } + const mg_value *role_val = mg_map_at(server, "role"); + if (role_val == nullptr || ToString(mg_value_string(role_val)) != "WRITE") { + continue; + } + const mg_value *addresses_val = mg_map_at(server, "addresses"); + if (addresses_val == nullptr) { + continue; + } + const mg_list *addresses = mg_value_list(addresses_val); + if (mg_list_size(addresses) == 0) { + continue; + } + write_address = ToString(mg_value_string(mg_list_at(addresses, 0))); + break; + } + + if (!write_address || write_address->empty()) { + console::EchoFailure("Routing failure", "no WRITE (main) instance in routing table"); + return mg_memory::MakeCustomUnique(nullptr); + } + + const auto host_port = SplitHostPort(*write_address); + if (!host_port) { + console::EchoFailure("Routing failure", "could not parse WRITE instance address '" + *write_address + "'"); + return mg_memory::MakeCustomUnique(nullptr); + } + + // 4. Connect directly to the resolved main and drop the coordinator session. This goes through + // MakeBoltSession (not MakeDirectSession) so the configured db is eagerly validated on the main, which runs + // Cypher. 
The first/coordinator session above is intentionally never probed (coordinators reject Cypher). + Config main_config = config; + main_config.host = host_port->first; + main_config.port = host_port->second; + main_config.routed_connection = false; + return MakeBoltSession(main_config); +} + +RoutedSession::RoutedSession(Config config) + : config_(std::move(config)), + session_(mg_memory::MakeCustomUnique(nullptr)), + routed_(config_.routed_connection) { + Rebuild(); +} + +void RoutedSession::Rebuild() { + in_transaction_ = false; + if (routed_) { + session_ = MakeRoutedBoltSession(config_, &expiry_); + } else { + // MakeBoltSession (not MakeDirectSession) so the configured db is eagerly validated on the direct path too. + session_ = MakeBoltSession(config_); + } +} + +mg_session *RoutedSession::Get() { + // Proactive TTL re-route: once the previous routing table has expired, fetch a fresh one (and possibly a new + // main) before handing back the session. Suppressed while an explicit transaction is open (re-routing would + // silently drop it) and made non-destructive: a failed re-route keeps the existing working session rather than + // discarding a healthy connection over a transient coordinator hiccup. + if (routed_ && !in_transaction_ && std::chrono::steady_clock::now() >= expiry_) { + auto refreshed = MakeRoutedBoltSession(config_, &expiry_); + if (refreshed.get() != nullptr || session_.get() == nullptr) { + session_ = std::move(refreshed); + } else { + // Keep the still-usable session; back off so we retry periodically instead of on every call. 
+ expiry_ = std::chrono::steady_clock::now() + kRerouteRetryBackoffSec; + } + } + return session_.get(); +} + +void RoutedSession::Reconnect() { Rebuild(); } + +bool RoutedSession::Connected() const { return session_.get() != nullptr; } + +void RoutedSession::ObserveQuery(const std::string &query) { + // Track explicit-transaction state from the transaction-control keyword so proactive re-routing can avoid + // tearing down an open transaction. Best-effort: matches the leading keyword of the trimmed query. + const auto upper = utils::ToUpperCase(utils::Trim(query)); + if (upper.rfind("BEGIN", 0) == 0) { + in_transaction_ = true; + } else if (upper.rfind("COMMIT", 0) == 0 || upper.rfind("ROLLBACK", 0) == 0) { + in_transaction_ = false; + } +} + } // namespace utils::bolt diff --git a/src/utils/bolt.hpp b/src/utils/bolt.hpp index 1faa82cc..04b5f7c8 100644 --- a/src/utils/bolt.hpp +++ b/src/utils/bolt.hpp @@ -15,18 +15,73 @@ #pragma once +#include + #include "utils.hpp" namespace utils::bolt { struct Config { + std::string db; std::string host; - int port; std::string username; std::string password; + int port; bool use_ssl; + bool routed_connection{false}; // when true, uses routing from coordinators }; +// Connects directly to config.host:config.port. Returns a null MgSessionPtr on +// failure (after reporting via console::EchoFailure). mg_memory::MgSessionPtr MakeBoltSession(const Config &config); +// Connects to the coordinator at config.host:config.port, fetches the routing +// table via a Bolt ROUTE message, locates the WRITE (main) data instance and +// returns a direct session to it. On success, *expiry_out is set to the point +// in time at which the routing table's TTL expires. Returns a null MgSessionPtr +// on failure (after reporting via console::EchoFailure). 
+mg_memory::MgSessionPtr MakeRoutedBoltSession(const Config &config, std::chrono::steady_clock::time_point *expiry_out); + +// A session abstraction that hides whether the connection is direct or routed. +// In routed mode it transparently re-fetches the routing table (re-routes) +// once the previous routing table's TTL has expired, keeping both connection +// modes on a single code path for the caller. +class RoutedSession { + public: + explicit RoutedSession(Config config); + + // Returns the underlying session. In routed mode, if the routing table TTL + // has expired (and no explicit transaction is open), transparently re-routes + // first; a failed re-route keeps the existing session. May return nullptr if + // a (re-)connection attempt failed and there is no prior session to fall back + // on. + mg_session *Get(); + + // Forces a rebuild of the session (direct or routed per config). Used by the + // interactive fatal-error path to drive failover. + void Reconnect(); + + // True if a session is currently established. + bool Connected() const; + + // Observes a query about to be executed so the session can track explicit + // transaction boundaries (BEGIN/COMMIT/ROLLBACK) and suppress proactive + // re-routing while a transaction is open. Safe to call in direct mode (no-op + // effect on routing). + void ObserveQuery(const std::string &query); + + private: + // Backoff applied after a failed proactive re-route so we don't retry on + // every Get() while the coordinator is briefly unreachable. 
+ static constexpr std::chrono::seconds kRerouteRetryBackoffSec{2}; + + void Rebuild(); + + Config config_; + mg_memory::MgSessionPtr session_; + std::chrono::steady_clock::time_point expiry_{}; + bool routed_; + bool in_transaction_{false}; +}; + } // namespace utils::bolt diff --git a/src/utils/utils.cpp b/src/utils/utils.cpp index 71f398a8..25347533 100644 --- a/src/utils/utils.cpp +++ b/src/utils/utils.cpp @@ -16,13 +16,14 @@ #include #include +#include #include #include #include #include +#include #include #include -#include #include #include #include @@ -83,6 +84,11 @@ std::string ToUpperCase(std::string s) { return s; } +auto ToLowerCase(std::string s) -> std::string { + std::transform(s.begin(), s.end(), s.begin(), [](char c) { return tolower(c); }); + return s; +} + std::string Trim(const std::string &s) { auto begin = s.begin(); auto end = s.end(); @@ -314,7 +320,8 @@ void PrintValue(std::ostream &os, const mg_date_time *date_time) { } void PrintValue(std::ostream &os, const mg_date_time_zone_id *date_time_zone_id) { - PrintDateTimeComponents(os, mg_date_time_zone_id_seconds(date_time_zone_id), mg_date_time_zone_id_nanoseconds(date_time_zone_id)); + PrintDateTimeComponents(os, mg_date_time_zone_id_seconds(date_time_zone_id), + mg_date_time_zone_id_nanoseconds(date_time_zone_id)); os << "["; PrintStringUnescaped(os, mg_date_time_zone_id_timezone_name(date_time_zone_id)); os << "]"; @@ -947,8 +954,17 @@ void PrintQueryInfo(const Query &query) { std::cout << "line: " << query.line_number << " index: " << query.index << " query: " << query.query << std::endl; } -QueryResult ExecuteQuery(mg_session *session, const std::string &query, const mg_map *params) { - int status = mg_session_run(session, query.c_str(), params, nullptr, nullptr, nullptr); +QueryResult ExecuteQuery(mg_session *session, const std::string &query, const mg_map *params, const std::string &db) { + // Select the target database via the RUN extra metadata when a db is configured (multi-tenant 
select). + mg_memory::MgMapPtr extra = mg_memory::MakeCustomUnique(nullptr); + if (!db.empty()) { + extra = mg_memory::MakeCustomUnique(mg_map_make_empty(1)); + if (!extra) { + throw utils::ClientFatalException("out of memory, failed to allocate the RUN extra map"); + } + mg_map_insert(extra.get(), "db", mg_value_make_string(db.c_str())); + } + int status = mg_session_run(session, query.c_str(), params, extra.get(), nullptr, nullptr); auto start = std::chrono::system_clock::now(); if (status != 0) { if (mg_session_status(session) == MG_SESSION_BAD) { @@ -1043,13 +1059,24 @@ void PrintBatchesInfo(const std::vector &batches) { } } -BatchResult ExecuteBatch(mg_session *session, const Batch &batch) { +BatchResult ExecuteBatch(mg_session *session, const Batch &batch, const std::string &db) { if (session == nullptr) { std::cout << "Session uninitialized" << std::endl; return BatchResult{.is_executed = false}; } mg_result *result; - auto begin_status = mg_session_begin_transaction(session, nullptr); + // Select the target database for the whole transaction via the BEGIN extra metadata. Once inside the explicit + // transaction the per-query RUN extra is ignored by the server, so the inner ExecuteQuery calls omit the db. 
+ mg_memory::MgMapPtr begin_extra = mg_memory::MakeCustomUnique(nullptr); + if (!db.empty()) { + begin_extra = mg_memory::MakeCustomUnique(mg_map_make_empty(1)); + if (!begin_extra) { + std::cout << "Unable to start transaction: out of memory" << std::endl; + return BatchResult{.is_executed = false}; + } + mg_map_insert(begin_extra.get(), "db", mg_value_make_string(db.c_str())); + } + auto begin_status = mg_session_begin_transaction(session, begin_extra.get()); if (begin_status != 0) { auto error = mg_session_error(session); std::cout << "Unable to start transaction: " << error << std::endl; diff --git a/src/utils/utils.hpp b/src/utils/utils.hpp index 62faaf17..31aa4a7c 100644 --- a/src/utils/utils.hpp +++ b/src/utils/utils.hpp @@ -65,6 +65,11 @@ fs::path GetUserHomeDir(); */ std::string ToUpperCase(std::string s); +/** + * return string with all lowercased characters (locale independent) + */ +auto ToLowerCase(std::string s) -> std::string; + /** * removes whitespace characters from the start and from the end of a string. * @@ -282,8 +287,12 @@ struct BatchResult { // The extra part is preserved for the next GetQuery call std::optional GetQuery(Replxx *replxx_instance, bool collect_info = false); -QueryResult ExecuteQuery(mg_session *session, const std::string &query, const mg_map *params = nullptr); -BatchResult ExecuteBatch(mg_session *session, const Batch &batch); +// When `db` is non-empty it is sent as the `db` field of the RUN extra metadata, selecting the target database +// for multi-tenant (enterprise) Memgraph. Community ignores it; enterprise validates it and fails with an +// "Unknown database name" error if it doesn't exist. +QueryResult ExecuteQuery(mg_session *session, const std::string &query, const mg_map *params = nullptr, + const std::string &db = ""); +BatchResult ExecuteBatch(mg_session *session, const Batch &batch, const std::string &db = ""); } // namespace query