Make the handshake mirrored again, for common code and to support future key rotation.

This commit is contained in:
Ian Gulliver
2015-02-08 19:50:09 +00:00
parent 7807df9575
commit 6a4a92f47a
2 changed files with 66 additions and 48 deletions

105
crypto.cc
View File

@@ -104,7 +104,7 @@ CryptoKey::~CryptoKey() {
void CryptoKey::WriteToFile(const std::string& filename) const { void CryptoKey::WriteToFile(const std::string& filename) const {
assert(is_set_); assert(is_set_);
int fd = open(filename.c_str(), O_WRONLY); int fd = open(filename.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0400);
assert(fd != -1); assert(fd != -1);
assert(write(fd, key_, key_bytes_) == key_bytes_); assert(write(fd, key_, key_bytes_) == key_bytes_);
assert(!close(fd)); assert(!close(fd));
@@ -124,6 +124,10 @@ const unsigned char* CryptoKey::Key() const {
return key_; return key_;
} }
bool CryptoKey::IsSet() const {
return is_set_;
}
unsigned char* CryptoKey::MutableKey() { unsigned char* CryptoKey::MutableKey() {
assert(!is_set_); assert(!is_set_);
return key_; return key_;
@@ -201,6 +205,25 @@ std::unique_ptr<TLVNode> CryptoPubConnBase::BuildSecureHandshake() {
return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake); return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake);
} }
std::unique_ptr<TLVNode> CryptoPubConnBase::BuildHandshake() {
auto secure_handshake = BuildSecureHandshake();
std::unique_ptr<TLVNode> handshake(new TLVNode(TLV_TYPE_HANDSHAKE));
PublicKey public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake->AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString()));
handshake->AppendChild(secure_handshake.release());
return handshake;
}
void CryptoPubConnBase::SendHandshake() {
auto handshake = BuildHandshake();
std::string out;
handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) {
assert(node.GetType() == TLV_TYPE_ENCRYPTED); assert(node.GetType() == TLV_TYPE_ENCRYPTED);
@@ -223,6 +246,40 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) {
return true; return true;
} }
bool CryptoPubConnBase::HandleHandshake(const TLVNode& node) {
if (node.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal("Protocol error (handshake; wrong message type)");
return false;
}
auto peer_public_key = node.FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_public_key) {
LogFatal("Protocol error (handshake; no public key)");
return false;
}
if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) {
LogFatal("Protocol error (handshake; wrong public key length)");
return false;
}
if (peer_public_key_.IsSet()) {
// We're the client and already know the server public key; we expect these to match.
// Eventually, we can do smarter things here to allow key rotation.
if (peer_public_key_.AsString() != peer_public_key->GetValue()) {
LogFatal("Protocol error (handshake; public key mismatch)");
return false;
}
} else {
peer_public_key_.FromString(peer_public_key->GetValue());
}
auto encrypted = node.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal("Protocol error (handshake; no encrypted portion)");
return false;
}
return HandleSecureHandshake(*encrypted);
}
void CryptoPubConnBase::EncryptSend(const TLVNode& node) { void CryptoPubConnBase::EncryptSend(const TLVNode& node) {
auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_secret_key_, peer_ephemeral_public_key_, node); auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_secret_key_, peer_ephemeral_public_key_, node);
std::string out; std::string out;
@@ -334,28 +391,7 @@ CryptoPubServerConnection::~CryptoPubServerConnection() {
} }
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_HANDSHAKE) { if (!HandleHandshake(decoded)) {
LogFatal("Protocol error (client handshake -- wrong message type)");
return;
}
auto peer_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_public_key) {
LogFatal("Protocol error (client handshake -- no public key)");
return;
}
if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) {
LogFatal("Protocol error (client handshake -- wrong public key length)");
return;
}
peer_public_key_.FromString(peer_public_key->GetValue());
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal("Protocol error (client handshake -- no encrypted portion)");
return;
}
if (!HandleSecureHandshake(*encrypted)) {
return; return;
} }
@@ -385,13 +421,6 @@ bool CryptoPubServerConnection::OnTunnelRequest(const TLVNode& message) {
return true; return true;
} }
void CryptoPubServerConnection::SendHandshake() {
auto handshake = BuildSecureHandshake();
std::string out;
handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
void CryptoPubServerConnection::OnError_(struct bufferevent* bev, const short what, void* this__) { void CryptoPubServerConnection::OnError_(struct bufferevent* bev, const short what, void* this__) {
auto this_ = (CryptoPubServerConnection*)this__; auto this_ = (CryptoPubServerConnection*)this__;
this_->OnError(what); this_->OnError(what);
@@ -432,7 +461,7 @@ CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address
} }
void CryptoPubClient::OnHandshake(const TLVNode& decoded) { void CryptoPubClient::OnHandshake(const TLVNode& decoded) {
if (!HandleSecureHandshake(decoded)) { if (!HandleHandshake(decoded)) {
return; return;
} }
@@ -463,20 +492,6 @@ void CryptoPubClient::OnConnect() {
SendHandshake(); SendHandshake();
} }
void CryptoPubClient::SendHandshake() {
auto secure_handshake = BuildSecureHandshake();
TLVNode handshake(TLV_TYPE_HANDSHAKE);
PublicKey public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString()));
handshake.AppendChild(secure_handshake.release());
std::string out;
handshake.Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
void CryptoPubClient::SendTunnelRequest() { void CryptoPubClient::SendTunnelRequest() {
TLVNode tunnel_request(TLV_TYPE_TUNNEL_REQUEST); TLVNode tunnel_request(TLV_TYPE_TUNNEL_REQUEST);
for (auto channel_bitrate : channel_bitrates_) { for (auto channel_bitrate : channel_bitrates_) {

View File

@@ -15,6 +15,7 @@ class CryptoKey {
void WriteToFile(const std::string& filename) const; void WriteToFile(const std::string& filename) const;
const unsigned char* Key() const; const unsigned char* Key() const;
bool IsSet() const;
unsigned char* MutableKey(); unsigned char* MutableKey();
void MarkSet(); void MarkSet();
@@ -67,7 +68,12 @@ class CryptoPubConnBase : public CryptoBase {
void LogFatal(const std::string& msg, void *obj=nullptr); void LogFatal(const std::string& msg, void *obj=nullptr);
std::unique_ptr<TLVNode> BuildSecureHandshake(); std::unique_ptr<TLVNode> BuildSecureHandshake();
std::unique_ptr<TLVNode> BuildHandshake();
void SendHandshake();
bool HandleSecureHandshake(const TLVNode& node); bool HandleSecureHandshake(const TLVNode& node);
bool HandleHandshake(const TLVNode& node);
void EncryptSend(const TLVNode& node); void EncryptSend(const TLVNode& node);
static void OnReadable_(struct bufferevent* bev, void* this__); static void OnReadable_(struct bufferevent* bev, void* this__);
@@ -123,8 +129,6 @@ class CryptoPubServerConnection : public CryptoPubConnBase {
static void OnError_(struct bufferevent* bev, const short what, void* this__); static void OnError_(struct bufferevent* bev, const short what, void* this__);
void OnError(const short what); void OnError(const short what);
void SendHandshake();
friend CryptoPubServer; friend CryptoPubServer;
}; };
@@ -145,7 +149,6 @@ class CryptoPubClient : public CryptoPubConnBase {
void OnConnect(); void OnConnect();
void OnError(); void OnError();
void SendHandshake();
void SendTunnelRequest(); void SendTunnelRequest();
struct event_base* event_base_; struct event_base* event_base_;