WorkerProto: Support fine-grained protocol feature negotiation

Currently, the worker protocol has a version number that we increment
whenever we change something in the protocol. However, this can cause
a collision between Nix PRs / forks that make protocol changes
(e.g. PR #9857 increments the version, which could collide with
another PR). So instead, the client and daemon now exchange a set of
protocol features (such as `auth-forwarding`). They will use the
intersection of the sets of features, i.e. the features they both
support.

Note that protocol features are completely distinct from
`ExperimentalFeature`s.
This commit is contained in:
Eelco Dolstra 2024-07-19 15:48:19 +02:00
parent b13ba7490c
commit 3be7c0037e
6 changed files with 127 additions and 29 deletions

View file

@ -1025,19 +1025,20 @@ void processConnection(
#endif
/* Exchange the greeting. */
WorkerProto::Version clientVersion =
auto [protoVersion, features] =
WorkerProto::BasicServerConnection::handshake(
to, from, PROTOCOL_VERSION);
to, from, PROTOCOL_VERSION, WorkerProto::allFeatures);
if (clientVersion < 0x10a)
if (protoVersion < 0x10a)
throw Error("the Nix client version is too old");
WorkerProto::BasicServerConnection conn;
conn.to = std::move(to);
conn.from = std::move(from);
conn.protoVersion = clientVersion;
conn.protoVersion = protoVersion;
conn.features = features;
auto tunnelLogger = new TunnelLogger(conn.to, clientVersion);
auto tunnelLogger = new TunnelLogger(conn.to, protoVersion);
auto prevLogger = nix::logger;
// FIXME
if (!recursive)

View file

@ -73,8 +73,11 @@ void RemoteStore::initConnection(Connection & conn)
StringSink saved;
TeeSource tee(conn.from, saved);
try {
conn.protoVersion = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION);
auto [protoVersion, features] = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION,
WorkerProto::allFeatures);
conn.protoVersion = protoVersion;
conn.features = features;
} catch (SerialisationError & e) {
/* In case the other side is waiting for our input, close
it. */
@ -88,6 +91,9 @@ void RemoteStore::initConnection(Connection & conn)
static_cast<WorkerProto::ClientHandshakeInfo &>(conn) = conn.postHandshake(*this);
for (auto & feature : conn.features)
debug("negotiated feature '%s'", feature);
auto ex = conn.processStderrReturn();
if (ex) std::rethrow_exception(ex);
}

View file

@ -5,6 +5,8 @@
namespace nix {
const std::set<WorkerProto::Feature> WorkerProto::allFeatures{};
WorkerProto::BasicClientConnection::~BasicClientConnection()
{
try {
@ -137,8 +139,21 @@ void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, S
}
}
WorkerProto::Version
WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
static std::set<WorkerProto::Feature>
intersectFeatures(const std::set<WorkerProto::Feature> & a, const std::set<WorkerProto::Feature> & b)
{
std::set<WorkerProto::Feature> res;
for (auto & x : a)
if (b.contains(x))
res.insert(x);
return res;
}
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicClientConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
to << WORKER_MAGIC_1 << localVersion;
to.flush();
@ -153,11 +168,24 @@ WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from,
if (GET_PROTOCOL_MINOR(daemonVersion) < 10)
throw Error("the Nix daemon version is too old");
return std::min(daemonVersion, localVersion);
auto protoVersion = std::min(daemonVersion, localVersion);
/* Exchange features. */
std::set<WorkerProto::Feature> daemonFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
to << supportedFeatures;
to.flush();
daemonFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
}
return {protoVersion, intersectFeatures(daemonFeatures, supportedFeatures)};
}
WorkerProto::Version
WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicServerConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
unsigned int magic = readInt(from);
if (magic != WORKER_MAGIC_1)
@ -165,7 +193,18 @@ WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from,
to << WORKER_MAGIC_2 << localVersion;
to.flush();
auto clientVersion = readInt(from);
return std::min(clientVersion, localVersion);
auto protoVersion = std::min(clientVersion, localVersion);
/* Exchange features. */
std::set<WorkerProto::Feature> clientFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
clientFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
to << supportedFeatures;
to.flush();
}
return {protoVersion, intersectFeatures(clientFeatures, supportedFeatures)};
}
WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store)

View file

@ -23,6 +23,11 @@ struct WorkerProto::BasicConnection
*/
WorkerProto::Version protoVersion;
/**
* The set of features that both sides support.
*/
std::set<Feature> features;
/**
* Coercion to `WorkerProto::ReadConn`. This makes it easy to use the
* factored out serve protocol serializers with a
@ -72,8 +77,8 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
/**
* Establishes connection, negotiating version.
*
* @return the version provided by the other side of the
* connection.
* @return the minimum version supported by both sides and the set
* of protocol features supported by both sides.
*
* @param to Taken by reference to allow for various error handling
* mechanisms.
@ -82,8 +87,15 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);
/**
* After calling handshake, must call this to exchange some basic
@ -138,8 +150,15 @@ struct WorkerProto::BasicServerConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);
/**
* After calling handshake, must call this to exchange some basic

View file

@ -11,7 +11,9 @@ namespace nix {
#define WORKER_MAGIC_1 0x6e697863
#define WORKER_MAGIC_2 0x6478696f
#define PROTOCOL_VERSION (1 << 8 | 37)
/* Note: you generally shouldn't change the protocol version. Define a
new `WorkerProto::Feature` instead. */
#define PROTOCOL_VERSION (1 << 8 | 38)
#define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00)
#define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff)
@ -131,6 +133,10 @@ struct WorkerProto
{
WorkerProto::Serialise<T>::write(store, conn, t);
}
using Feature = std::string;
static const std::set<Feature> allFeatures;
};
enum struct WorkerProto::Op : uint64_t

View file

@ -658,15 +658,15 @@ TEST_F(WorkerProtoTest, handshake_log)
FdSink out { toServer.writeSide.get() };
FdSource in0 { toClient.readSide.get() };
TeeSource in { in0, toClientLog };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion);
clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion, {}));
});
{
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
WorkerProto::BasicServerConnection::handshake(
out, in, defaultVersion);
out, in, defaultVersion, {});
};
thread.join();
@ -675,6 +675,33 @@ TEST_F(WorkerProtoTest, handshake_log)
});
}
TEST_F(WorkerProtoTest, handshake_features)
{
Pipe toClient, toServer;
toClient.create();
toServer.create();
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> clientResult;
auto clientThread = std::thread([&]() {
FdSink out { toServer.writeSide.get() };
FdSource in { toClient.readSide.get() };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, 123, {"bar", "aap", "mies", "xyzzy"});
});
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
auto daemonResult = WorkerProto::BasicServerConnection::handshake(
out, in, 456, {"foo", "bar", "xyzzy"});
clientThread.join();
EXPECT_EQ(clientResult, daemonResult);
EXPECT_EQ(std::get<0>(clientResult), 123);
EXPECT_EQ(std::get<1>(clientResult), std::set<WorkerProto::Feature>({"bar", "xyzzy"}));
}
/// Has to be a `BufferedSink` for handshake.
struct NullBufferedSink : BufferedSink {
void writeUnbuffered(std::string_view data) override { }
@ -686,8 +713,8 @@ TEST_F(WorkerProtoTest, handshake_client_replay)
NullBufferedSink nullSink;
StringSource in { toClientLog };
auto clientResult = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));
EXPECT_EQ(clientResult, defaultVersion);
});
@ -705,13 +732,13 @@ TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws)
if (len < 8) {
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
EndOfFile);
} else {
// Not sure why cannot keep on checking for `EndOfFile`.
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
}
}
@ -734,17 +761,17 @@ TEST_F(WorkerProtoTest, handshake_client_corrupted_throws)
// magic bytes don't match
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
} else if (idx < 8 || idx >= 12) {
// Number out of bounds
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
SerialisationError);
} else {
auto ver = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto ver = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));
// `std::min` of this and the other version saves us
EXPECT_EQ(ver, defaultVersion);
}