diff --git a/src/libstore/daemon.cc b/src/libstore/daemon.cc index 6533b2f58..94f00cfb6 100644 --- a/src/libstore/daemon.cc +++ b/src/libstore/daemon.cc @@ -1025,19 +1025,20 @@ void processConnection( #endif /* Exchange the greeting. */ - WorkerProto::Version clientVersion = + auto [protoVersion, features] = WorkerProto::BasicServerConnection::handshake( - to, from, PROTOCOL_VERSION); + to, from, PROTOCOL_VERSION, WorkerProto::allFeatures); - if (clientVersion < 0x10a) + if (protoVersion < 0x10a) throw Error("the Nix client version is too old"); WorkerProto::BasicServerConnection conn; conn.to = std::move(to); conn.from = std::move(from); - conn.protoVersion = clientVersion; + conn.protoVersion = protoVersion; + conn.features = features; - auto tunnelLogger = new TunnelLogger(conn.to, clientVersion); + auto tunnelLogger = new TunnelLogger(conn.to, protoVersion); auto prevLogger = nix::logger; // FIXME if (!recursive) diff --git a/src/libstore/remote-store.cc b/src/libstore/remote-store.cc index ebb0864c5..555936c18 100644 --- a/src/libstore/remote-store.cc +++ b/src/libstore/remote-store.cc @@ -73,8 +73,11 @@ void RemoteStore::initConnection(Connection & conn) StringSink saved; TeeSource tee(conn.from, saved); try { - conn.protoVersion = WorkerProto::BasicClientConnection::handshake( - conn.to, tee, PROTOCOL_VERSION); + auto [protoVersion, features] = WorkerProto::BasicClientConnection::handshake( + conn.to, tee, PROTOCOL_VERSION, + WorkerProto::allFeatures); + conn.protoVersion = protoVersion; + conn.features = features; } catch (SerialisationError & e) { /* In case the other side is waiting for our input, close it. */ @@ -88,6 +91,9 @@ void RemoteStore::initConnection(Connection & conn) static_cast(conn) = conn.postHandshake(*this); + for (auto & feature : conn.features) + debug("negotiated feature '%s'", feature); + auto ex = conn.processStderrReturn(); if (ex) std::rethrow_exception(ex); } diff --git a/src/libstore/worker-protocol-connection.cc b/src/libstore/worker-protocol-connection.cc index 93d13d48e..a47dbb689 100644 --- a/src/libstore/worker-protocol-connection.cc +++ b/src/libstore/worker-protocol-connection.cc @@ -5,6 +5,8 @@ namespace nix { +const std::set WorkerProto::allFeatures{}; + WorkerProto::BasicClientConnection::~BasicClientConnection() { try { @@ -137,8 +139,21 @@ void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, S } } -WorkerProto::Version -WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) +static std::set +intersectFeatures(const std::set & a, const std::set & b) +{ + std::set res; + for (auto & x : a) + if (b.contains(x)) + res.insert(x); + return res; +} + +std::tuple> WorkerProto::BasicClientConnection::handshake( + BufferedSink & to, + Source & from, + WorkerProto::Version localVersion, + const std::set & supportedFeatures) { to << WORKER_MAGIC_1 << localVersion; to.flush(); @@ -153,11 +168,24 @@ WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, if (GET_PROTOCOL_MINOR(daemonVersion) < 10) throw Error("the Nix daemon version is too old"); - return std::min(daemonVersion, localVersion); + auto protoVersion = std::min(daemonVersion, localVersion); + + /* Exchange features. */ + std::set daemonFeatures; + if (GET_PROTOCOL_MINOR(protoVersion) >= 38) { + to << supportedFeatures; + to.flush(); + daemonFeatures = readStrings>(from); + } + + return {protoVersion, intersectFeatures(daemonFeatures, supportedFeatures)}; } -WorkerProto::Version -WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) +std::tuple> WorkerProto::BasicServerConnection::handshake( + BufferedSink & to, + Source & from, + WorkerProto::Version localVersion, + const std::set & supportedFeatures) { unsigned int magic = readInt(from); if (magic != WORKER_MAGIC_1) @@ -165,7 +193,18 @@ WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, to << WORKER_MAGIC_2 << localVersion; to.flush(); auto clientVersion = readInt(from); - return std::min(clientVersion, localVersion); + + auto protoVersion = std::min(clientVersion, localVersion); + + /* Exchange features. */ + std::set clientFeatures; + if (GET_PROTOCOL_MINOR(protoVersion) >= 38) { + clientFeatures = readStrings>(from); + to << supportedFeatures; + to.flush(); + } + + return {protoVersion, intersectFeatures(clientFeatures, supportedFeatures)}; } WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store) diff --git a/src/libstore/worker-protocol-connection.hh b/src/libstore/worker-protocol-connection.hh index 38287d08e..9c96195b5 100644 --- a/src/libstore/worker-protocol-connection.hh +++ b/src/libstore/worker-protocol-connection.hh @@ -23,6 +23,11 @@ struct WorkerProto::BasicConnection */ WorkerProto::Version protoVersion; + /** + * The set of features that both sides support. + */ + std::set features; + /** * Coercion to `WorkerProto::ReadConn`. This makes it easy to use the * factored out serve protocol serializers with a @@ -72,8 +77,8 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection /** * Establishes connection, negotiating version. * - * @return the version provided by the other side of the - * connection. + * @return the minimum version supported by both sides and the set + * of protocol features supported by both sides. * * @param to Taken by reference to allow for various error handling * mechanisms. @@ -82,8 +87,15 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection * handling mechanisms. * * @param localVersion Our version which is sent over + * + * @param features The protocol features that we support */ - static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); + // FIXME: this should probably be a constructor. + static std::tuple> handshake( + BufferedSink & to, + Source & from, + WorkerProto::Version localVersion, + const std::set & supportedFeatures); /** * After calling handshake, must call this to exchange some basic @@ -138,8 +150,15 @@ struct WorkerProto::BasicServerConnection : WorkerProto::BasicConnection * handling mechanisms. * * @param localVersion Our version which is sent over + * + * @param features The protocol features that we support */ - static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); + // FIXME: this should probably be a constructor. + static std::tuple> handshake( + BufferedSink & to, + Source & from, + WorkerProto::Version localVersion, + const std::set & supportedFeatures); /** * After calling handshake, must call this to exchange some basic diff --git a/src/libstore/worker-protocol.hh b/src/libstore/worker-protocol.hh index 9fc16d015..c356fa1bf 100644 --- a/src/libstore/worker-protocol.hh +++ b/src/libstore/worker-protocol.hh @@ -11,7 +11,9 @@ namespace nix { #define WORKER_MAGIC_1 0x6e697863 #define WORKER_MAGIC_2 0x6478696f -#define PROTOCOL_VERSION (1 << 8 | 37) +/* Note: you generally shouldn't change the protocol version. Define a + new `WorkerProto::Feature` instead. */ +#define PROTOCOL_VERSION (1 << 8 | 38) #define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00) #define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff) @@ -131,6 +133,10 @@ struct WorkerProto { WorkerProto::Serialise::write(store, conn, t); } + + using Feature = std::string; + + static const std::set allFeatures; }; enum struct WorkerProto::Op : uint64_t diff --git a/tests/unit/libstore/worker-protocol.cc b/tests/unit/libstore/worker-protocol.cc index c15120010..bbea9ed75 100644 --- a/tests/unit/libstore/worker-protocol.cc +++ b/tests/unit/libstore/worker-protocol.cc @@ -658,15 +658,15 @@ TEST_F(WorkerProtoTest, handshake_log) FdSink out { toServer.writeSide.get() }; FdSource in0 { toClient.readSide.get() }; TeeSource in { in0, toClientLog }; - clientResult = WorkerProto::BasicClientConnection::handshake( - out, in, defaultVersion); + clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake( + out, in, defaultVersion, {})); }); { FdSink out { toClient.writeSide.get() }; FdSource in { toServer.readSide.get() }; WorkerProto::BasicServerConnection::handshake( - out, in, defaultVersion); + out, in, defaultVersion, {}); }; thread.join(); @@ -675,6 +675,33 @@ TEST_F(WorkerProtoTest, handshake_log) }); } +TEST_F(WorkerProtoTest, handshake_features) +{ + Pipe toClient, toServer; + toClient.create(); + toServer.create(); + + std::tuple> clientResult; + + auto clientThread = std::thread([&]() { + FdSink out { toServer.writeSide.get() }; + FdSource in { toClient.readSide.get() }; + clientResult = WorkerProto::BasicClientConnection::handshake( + out, in, 123, {"bar", "aap", "mies", "xyzzy"}); + }); + + FdSink out { toClient.writeSide.get() }; + FdSource in { toServer.readSide.get() }; + auto daemonResult = WorkerProto::BasicServerConnection::handshake( + out, in, 456, {"foo", "bar", "xyzzy"}); + + clientThread.join(); + + EXPECT_EQ(clientResult, daemonResult); + EXPECT_EQ(std::get<0>(clientResult), 123); + EXPECT_EQ(std::get<1>(clientResult), std::set({"bar", "xyzzy"})); +} + /// Has to be a `BufferedSink` for handshake. struct NullBufferedSink : BufferedSink { void writeUnbuffered(std::string_view data) override { } @@ -686,8 +713,8 @@ TEST_F(WorkerProtoTest, handshake_client_replay) NullBufferedSink nullSink; StringSource in { toClientLog }; - auto clientResult = WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion); + auto clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, {})); EXPECT_EQ(clientResult, defaultVersion); }); @@ -705,13 +732,13 @@ TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws) if (len < 8) { EXPECT_THROW( WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion), + nullSink, in, defaultVersion, {}), EndOfFile); } else { // Not sure why cannot keep on checking for `EndOfFile`. EXPECT_THROW( WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion), + nullSink, in, defaultVersion, {}), Error); } } @@ -734,17 +761,17 @@ TEST_F(WorkerProtoTest, handshake_client_corrupted_throws) // magic bytes don't match EXPECT_THROW( WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion), + nullSink, in, defaultVersion, {}), Error); } else if (idx < 8 || idx >= 12) { // Number out of bounds EXPECT_THROW( WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion), + nullSink, in, defaultVersion, {}), SerialisationError); } else { - auto ver = WorkerProto::BasicClientConnection::handshake( - nullSink, in, defaultVersion); + auto ver = std::get<0>(WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion, {})); // `std::min` of this and the other version saves us EXPECT_EQ(ver, defaultVersion); }