From 6a4a92f47a81f3a932c68116691e147eb2bacbe1 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sun, 8 Feb 2015 19:50:09 +0000 Subject: [PATCH] Make the handshake mirrored again, for common code and to support future key rotation. --- crypto.cc | 105 +++++++++++++++++++++++++++++++----------------------- crypto.h | 9 +++-- 2 files changed, 66 insertions(+), 48 deletions(-) diff --git a/crypto.cc b/crypto.cc index e2b3062..107be2c 100644 --- a/crypto.cc +++ b/crypto.cc @@ -104,7 +104,7 @@ CryptoKey::~CryptoKey() { void CryptoKey::WriteToFile(const std::string& filename) const { assert(is_set_); - int fd = open(filename.c_str(), O_WRONLY); + int fd = open(filename.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0400); assert(fd != -1); assert(write(fd, key_, key_bytes_) == key_bytes_); assert(!close(fd)); @@ -124,6 +124,10 @@ const unsigned char* CryptoKey::Key() const { return key_; } +bool CryptoKey::IsSet() const { + return is_set_; +} + unsigned char* CryptoKey::MutableKey() { assert(!is_set_); return key_; @@ -201,6 +205,25 @@ std::unique_ptr CryptoPubConnBase::BuildSecureHandshake() { return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake); } +std::unique_ptr CryptoPubConnBase::BuildHandshake() { + auto secure_handshake = BuildSecureHandshake(); + + std::unique_ptr handshake(new TLVNode(TLV_TYPE_HANDSHAKE)); + PublicKey public_key; + CryptoUtil::DerivePublicKey(secret_key_, &public_key); + handshake->AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString())); + handshake->AppendChild(secure_handshake.release()); + + return handshake; +} + +void CryptoPubConnBase::SendHandshake() { + auto handshake = BuildHandshake(); + std::string out; + handshake->Encode(&out); + bufferevent_write(bev_, out.data(), out.length()); +} + bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { assert(node.GetType() == TLV_TYPE_ENCRYPTED); @@ -223,6 +246,40 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { return true; } +bool CryptoPubConnBase::HandleHandshake(const TLVNode& node) { + if (node.GetType() != TLV_TYPE_HANDSHAKE) { + LogFatal("Protocol error (handshake; wrong message type)"); + return false; + } + + auto peer_public_key = node.FindChild(TLV_TYPE_PUBLIC_KEY); + if (!peer_public_key) { + LogFatal("Protocol error (handshake; no public key)"); + return false; + } + if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) { + LogFatal("Protocol error (handshake; wrong public key length)"); + return false; + } + if (peer_public_key_.IsSet()) { + // We're the client and already know the server public key; we expect these to match. + // Eventually, we can do smarter things here to allow key rotation. + if (peer_public_key_.AsString() != peer_public_key->GetValue()) { + LogFatal("Protocol error (handshake; public key mismatch)"); + return false; + } + } else { + peer_public_key_.FromString(peer_public_key->GetValue()); + } + auto encrypted = node.FindChild(TLV_TYPE_ENCRYPTED); + if (!encrypted) { + LogFatal("Protocol error (handshake; no encrypted portion)"); + return false; + } + + return HandleSecureHandshake(*encrypted); +} + void CryptoPubConnBase::EncryptSend(const TLVNode& node) { auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_secret_key_, peer_ephemeral_public_key_, node); std::string out; @@ -334,28 +391,7 @@ CryptoPubServerConnection::~CryptoPubServerConnection() { } void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { - if (decoded.GetType() != TLV_TYPE_HANDSHAKE) { - LogFatal("Protocol error (client handshake -- wrong message type)"); - return; - } - - auto peer_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY); - if (!peer_public_key) { - LogFatal("Protocol error (client handshake -- no public key)"); - return; - } - if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) { - LogFatal("Protocol error (client handshake -- wrong public key length)"); - return; - } - peer_public_key_.FromString(peer_public_key->GetValue()); - auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); - if (!encrypted) { - LogFatal("Protocol error (client handshake -- no encrypted portion)"); - return; - } - - if (!HandleSecureHandshake(*encrypted)) { + if (!HandleHandshake(decoded)) { return; } @@ -385,13 +421,6 @@ bool CryptoPubServerConnection::OnTunnelRequest(const TLVNode& message) { return true; } -void CryptoPubServerConnection::SendHandshake() { - auto handshake = BuildSecureHandshake(); - std::string out; - handshake->Encode(&out); - bufferevent_write(bev_, out.data(), out.length()); -} - void CryptoPubServerConnection::OnError_(struct bufferevent* bev, const short what, void* this__) { auto this_ = (CryptoPubServerConnection*)this__; this_->OnError(what); @@ -432,7 +461,7 @@ CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address } void CryptoPubClient::OnHandshake(const TLVNode& decoded) { - if (!HandleSecureHandshake(decoded)) { + if (!HandleHandshake(decoded)) { return; } @@ -463,20 +492,6 @@ void CryptoPubClient::OnConnect() { SendHandshake(); } -void CryptoPubClient::SendHandshake() { - auto secure_handshake = BuildSecureHandshake(); - - TLVNode handshake(TLV_TYPE_HANDSHAKE); - PublicKey public_key; - CryptoUtil::DerivePublicKey(secret_key_, &public_key); - handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString())); - handshake.AppendChild(secure_handshake.release()); - - std::string out; - handshake.Encode(&out); - bufferevent_write(bev_, out.data(), out.length()); -} - void CryptoPubClient::SendTunnelRequest() { TLVNode tunnel_request(TLV_TYPE_TUNNEL_REQUEST); for (auto channel_bitrate : channel_bitrates_) { diff --git a/crypto.h b/crypto.h index 8adc5cc..ea14837 100644 --- a/crypto.h +++ b/crypto.h @@ -15,6 +15,7 @@ class CryptoKey { void WriteToFile(const std::string& filename) const; const unsigned char* Key() const; + bool IsSet() const; unsigned char* MutableKey(); void MarkSet(); @@ -67,7 +68,12 @@ class CryptoPubConnBase : public CryptoBase { void LogFatal(const std::string& msg, void *obj=nullptr); std::unique_ptr BuildSecureHandshake(); + std::unique_ptr BuildHandshake(); + void SendHandshake(); + bool HandleSecureHandshake(const TLVNode& node); + bool HandleHandshake(const TLVNode& node); + void EncryptSend(const TLVNode& node); static void OnReadable_(struct bufferevent* bev, void* this__); @@ -123,8 +129,6 @@ class CryptoPubServerConnection : public CryptoPubConnBase { static void OnError_(struct bufferevent* bev, const short what, void* this__); void OnError(const short what); - void SendHandshake(); - friend CryptoPubServer; }; @@ -145,7 +149,6 @@ class CryptoPubClient : public CryptoPubConnBase { void OnConnect(); void OnError(); - void SendHandshake(); void SendTunnelRequest(); struct event_base* event_base_;