Merge pull request #11140 from DeterminateSystems/protocol-features

WorkerProto: Support fine-grained protocol feature negotiation
This commit is contained in:
Eelco Dolstra 2024-07-31 17:47:38 +02:00 committed by GitHub
commit ed0934b884
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 127 additions and 29 deletions

View file

@ -1025,19 +1025,20 @@ void processConnection(
#endif #endif
/* Exchange the greeting. */ /* Exchange the greeting. */
WorkerProto::Version clientVersion = auto [protoVersion, features] =
WorkerProto::BasicServerConnection::handshake( WorkerProto::BasicServerConnection::handshake(
to, from, PROTOCOL_VERSION); to, from, PROTOCOL_VERSION, WorkerProto::allFeatures);
if (clientVersion < 0x10a) if (protoVersion < 0x10a)
throw Error("the Nix client version is too old"); throw Error("the Nix client version is too old");
WorkerProto::BasicServerConnection conn; WorkerProto::BasicServerConnection conn;
conn.to = std::move(to); conn.to = std::move(to);
conn.from = std::move(from); conn.from = std::move(from);
conn.protoVersion = clientVersion; conn.protoVersion = protoVersion;
conn.features = features;
auto tunnelLogger = new TunnelLogger(conn.to, clientVersion); auto tunnelLogger = new TunnelLogger(conn.to, protoVersion);
auto prevLogger = nix::logger; auto prevLogger = nix::logger;
// FIXME // FIXME
if (!recursive) if (!recursive)

View file

@ -73,8 +73,11 @@ void RemoteStore::initConnection(Connection & conn)
StringSink saved; StringSink saved;
TeeSource tee(conn.from, saved); TeeSource tee(conn.from, saved);
try { try {
conn.protoVersion = WorkerProto::BasicClientConnection::handshake( auto [protoVersion, features] = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION); conn.to, tee, PROTOCOL_VERSION,
WorkerProto::allFeatures);
conn.protoVersion = protoVersion;
conn.features = features;
} catch (SerialisationError & e) { } catch (SerialisationError & e) {
/* In case the other side is waiting for our input, close /* In case the other side is waiting for our input, close
it. */ it. */
@ -88,6 +91,9 @@ void RemoteStore::initConnection(Connection & conn)
static_cast<WorkerProto::ClientHandshakeInfo &>(conn) = conn.postHandshake(*this); static_cast<WorkerProto::ClientHandshakeInfo &>(conn) = conn.postHandshake(*this);
for (auto & feature : conn.features)
debug("negotiated feature '%s'", feature);
auto ex = conn.processStderrReturn(); auto ex = conn.processStderrReturn();
if (ex) std::rethrow_exception(ex); if (ex) std::rethrow_exception(ex);
} }

View file

@ -5,6 +5,8 @@
namespace nix { namespace nix {
const std::set<WorkerProto::Feature> WorkerProto::allFeatures{};
WorkerProto::BasicClientConnection::~BasicClientConnection() WorkerProto::BasicClientConnection::~BasicClientConnection()
{ {
try { try {
@ -137,8 +139,21 @@ void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, S
} }
} }
WorkerProto::Version static std::set<WorkerProto::Feature>
WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) intersectFeatures(const std::set<WorkerProto::Feature> & a, const std::set<WorkerProto::Feature> & b)
{
std::set<WorkerProto::Feature> res;
for (auto & x : a)
if (b.contains(x))
res.insert(x);
return res;
}
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicClientConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{ {
to << WORKER_MAGIC_1 << localVersion; to << WORKER_MAGIC_1 << localVersion;
to.flush(); to.flush();
@ -153,11 +168,24 @@ WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from,
if (GET_PROTOCOL_MINOR(daemonVersion) < 10) if (GET_PROTOCOL_MINOR(daemonVersion) < 10)
throw Error("the Nix daemon version is too old"); throw Error("the Nix daemon version is too old");
return std::min(daemonVersion, localVersion); auto protoVersion = std::min(daemonVersion, localVersion);
/* Exchange features. */
std::set<WorkerProto::Feature> daemonFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
to << supportedFeatures;
to.flush();
daemonFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
} }
WorkerProto::Version return {protoVersion, intersectFeatures(daemonFeatures, supportedFeatures)};
WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) }
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicServerConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{ {
unsigned int magic = readInt(from); unsigned int magic = readInt(from);
if (magic != WORKER_MAGIC_1) if (magic != WORKER_MAGIC_1)
@ -165,7 +193,18 @@ WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from,
to << WORKER_MAGIC_2 << localVersion; to << WORKER_MAGIC_2 << localVersion;
to.flush(); to.flush();
auto clientVersion = readInt(from); auto clientVersion = readInt(from);
return std::min(clientVersion, localVersion);
auto protoVersion = std::min(clientVersion, localVersion);
/* Exchange features. */
std::set<WorkerProto::Feature> clientFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
clientFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
to << supportedFeatures;
to.flush();
}
return {protoVersion, intersectFeatures(clientFeatures, supportedFeatures)};
} }
WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store) WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store)

View file

@ -23,6 +23,11 @@ struct WorkerProto::BasicConnection
*/ */
WorkerProto::Version protoVersion; WorkerProto::Version protoVersion;
/**
* The set of features that both sides support.
*/
std::set<Feature> features;
/** /**
* Coercion to `WorkerProto::ReadConn`. This makes it easy to use the * Coercion to `WorkerProto::ReadConn`. This makes it easy to use the
* factored out serve protocol serializers with a * factored out serve protocol serializers with a
@ -72,8 +77,8 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
/** /**
* Establishes connection, negotiating version. * Establishes connection, negotiating version.
* *
* @return the version provided by the other side of the * @return the minimum version supported by both sides and the set
* connection. * of protocol features supported by both sides.
* *
* @param to Taken by reference to allow for various error handling * @param to Taken by reference to allow for various error handling
* mechanisms. * mechanisms.
@ -82,8 +87,15 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
* handling mechanisms. * handling mechanisms.
* *
* @param localVersion Our version which is sent over * @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/ */
static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); // FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);
/** /**
* After calling handshake, must call this to exchange some basic * After calling handshake, must call this to exchange some basic
@ -138,8 +150,15 @@ struct WorkerProto::BasicServerConnection : WorkerProto::BasicConnection
* handling mechanisms. * handling mechanisms.
* *
* @param localVersion Our version which is sent over * @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/ */
static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); // FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);
/** /**
* After calling handshake, must call this to exchange some basic * After calling handshake, must call this to exchange some basic

View file

@ -11,7 +11,9 @@ namespace nix {
#define WORKER_MAGIC_1 0x6e697863 #define WORKER_MAGIC_1 0x6e697863
#define WORKER_MAGIC_2 0x6478696f #define WORKER_MAGIC_2 0x6478696f
#define PROTOCOL_VERSION (1 << 8 | 37) /* Note: you generally shouldn't change the protocol version. Define a
new `WorkerProto::Feature` instead. */
#define PROTOCOL_VERSION (1 << 8 | 38)
#define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00) #define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00)
#define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff) #define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff)
@ -131,6 +133,10 @@ struct WorkerProto
{ {
WorkerProto::Serialise<T>::write(store, conn, t); WorkerProto::Serialise<T>::write(store, conn, t);
} }
using Feature = std::string;
static const std::set<Feature> allFeatures;
}; };
enum struct WorkerProto::Op : uint64_t enum struct WorkerProto::Op : uint64_t

View file

@ -658,15 +658,15 @@ TEST_F(WorkerProtoTest, handshake_log)
FdSink out { toServer.writeSide.get() }; FdSink out { toServer.writeSide.get() };
FdSource in0 { toClient.readSide.get() }; FdSource in0 { toClient.readSide.get() };
TeeSource in { in0, toClientLog }; TeeSource in { in0, toClientLog };
clientResult = WorkerProto::BasicClientConnection::handshake( clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion); out, in, defaultVersion, {}));
}); });
{ {
FdSink out { toClient.writeSide.get() }; FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() }; FdSource in { toServer.readSide.get() };
WorkerProto::BasicServerConnection::handshake( WorkerProto::BasicServerConnection::handshake(
out, in, defaultVersion); out, in, defaultVersion, {});
}; };
thread.join(); thread.join();
@ -675,6 +675,33 @@ TEST_F(WorkerProtoTest, handshake_log)
}); });
} }
TEST_F(WorkerProtoTest, handshake_features)
{
Pipe toClient, toServer;
toClient.create();
toServer.create();
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> clientResult;
auto clientThread = std::thread([&]() {
FdSink out { toServer.writeSide.get() };
FdSource in { toClient.readSide.get() };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, 123, {"bar", "aap", "mies", "xyzzy"});
});
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
auto daemonResult = WorkerProto::BasicServerConnection::handshake(
out, in, 456, {"foo", "bar", "xyzzy"});
clientThread.join();
EXPECT_EQ(clientResult, daemonResult);
EXPECT_EQ(std::get<0>(clientResult), 123);
EXPECT_EQ(std::get<1>(clientResult), std::set<WorkerProto::Feature>({"bar", "xyzzy"}));
}
/// Has to be a `BufferedSink` for handshake. /// Has to be a `BufferedSink` for handshake.
struct NullBufferedSink : BufferedSink { struct NullBufferedSink : BufferedSink {
void writeUnbuffered(std::string_view data) override { } void writeUnbuffered(std::string_view data) override { }
@ -686,8 +713,8 @@ TEST_F(WorkerProtoTest, handshake_client_replay)
NullBufferedSink nullSink; NullBufferedSink nullSink;
StringSource in { toClientLog }; StringSource in { toClientLog };
auto clientResult = WorkerProto::BasicClientConnection::handshake( auto clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion); nullSink, in, defaultVersion, {}));
EXPECT_EQ(clientResult, defaultVersion); EXPECT_EQ(clientResult, defaultVersion);
}); });
@ -705,13 +732,13 @@ TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws)
if (len < 8) { if (len < 8) {
EXPECT_THROW( EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake( WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion), nullSink, in, defaultVersion, {}),
EndOfFile); EndOfFile);
} else { } else {
// Not sure why cannot keep on checking for `EndOfFile`. // Not sure why cannot keep on checking for `EndOfFile`.
EXPECT_THROW( EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake( WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion), nullSink, in, defaultVersion, {}),
Error); Error);
} }
} }
@ -734,17 +761,17 @@ TEST_F(WorkerProtoTest, handshake_client_corrupted_throws)
// magic bytes don't match // magic bytes don't match
EXPECT_THROW( EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake( WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion), nullSink, in, defaultVersion, {}),
Error); Error);
} else if (idx < 8 || idx >= 12) { } else if (idx < 8 || idx >= 12) {
// Number out of bounds // Number out of bounds
EXPECT_THROW( EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake( WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion), nullSink, in, defaultVersion, {}),
SerialisationError); SerialisationError);
} else { } else {
auto ver = WorkerProto::BasicClientConnection::handshake( auto ver = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion); nullSink, in, defaultVersion, {}));
// `std::min` of this and the other version saves us // `std::min` of this and the other version saves us
EXPECT_EQ(ver, defaultVersion); EXPECT_EQ(ver, defaultVersion);
} }