From d82cb789e3608cb266f91d8fee59b69b1dc346db Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 7 Feb 2015 15:07:34 -0800 Subject: [PATCH] Switch to unique_ptr for ownership clarity in places, combine more code. --- crypto.cc | 112 +++++++++++++++++++++++------------------------------- crypto.h | 7 +++- tlv.cc | 6 +-- tlv.h | 2 +- 4 files changed, 56 insertions(+), 71 deletions(-) diff --git a/crypto.cc b/crypto.cc index 6a4873f..b67a754 100644 --- a/crypto.cc +++ b/crypto.cc @@ -21,12 +21,10 @@ #define TLV_TYPE_DOWNSTREAM_BITRATE 0x0003 #define TLV_TYPE_ENCRYPTED 0x8000 -#define TLV_TYPE_CLIENT_HANDSHAKE 0x8001 -#define TLV_TYPE_CLIENT_HANDSHAKE_SECURE 0x8002 -#define TLV_TYPE_SERVER_HANDSHAKE 0x8003 -#define TLV_TYPE_SERVER_HANDSHAKE_SECURE 0x8004 -#define TLV_TYPE_TUNNEL_REQUEST 0x8005 -#define TLV_TYPE_CHANNEL 0x8006 +#define TLV_TYPE_HANDSHAKE 0x8001 +#define TLV_TYPE_HANDSHAKE_SECURE 0x8002 +#define TLV_TYPE_TUNNEL_REQUEST 0x8003 +#define TLV_TYPE_CHANNEL 0x8004 std::string CryptoUtil::BinToHex(const std::string& bin) { @@ -73,7 +71,7 @@ void CryptoUtil::WriteKeyToFile(const std::string& filename, const std::string& key_file << key; } -void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container) { +std::unique_ptr CryptoUtil::EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { assert(secret_key.length() == crypto_box_SECRETKEYBYTES); assert(public_key.length() == crypto_box_PUBLICKEYBYTES); @@ -88,14 +86,14 @@ void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::s unsigned char output[encrypted_bytes]; assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data())); - auto encrypted = new TLVNode(TLV_TYPE_ENCRYPTED); + std::unique_ptr encrypted(new TLVNode(TLV_TYPE_ENCRYPTED)); encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES))); encrypted->AppendChild(new TLVNode(TLV_TYPE_ENCRYPTED_BLOB, std::string((char*)output, encrypted_bytes))); - container->AppendChild(encrypted); + return encrypted; } -TLVNode* CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { +std::unique_ptr CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { assert(secret_key.length() == crypto_box_SECRETKEYBYTES); assert(public_key.length() == crypto_box_PUBLICKEYBYTES); assert(input.GetType() == TLV_TYPE_ENCRYPTED); @@ -136,6 +134,37 @@ CryptoConnBase::CryptoConnBase(const std::string& secret_key) : secret_key_(secret_key), state_(AWAITING_HANDSHAKE) {} +std::unique_ptr CryptoConnBase::BuildSecureHandshake() { + std::string ephemeral_public_key; + CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); + + TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE); + secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key)); + return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake); +} + +bool CryptoConnBase::HandleSecureHandshake(const TLVNode& node) { + assert(node.GetType() == TLV_TYPE_ENCRYPTED); + + std::unique_ptr decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node)); + if (!decrypted.get()) { + LogFatal() << "Protocol error (handshake; decryption failure)" << std::endl; + return false; + } + + auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY); + if (!peer_ephemeral_public_key) { + LogFatal() << "Protocol error (handshake; no ephemeral public key)" << std::endl; + return false; + } + peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue(); + if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) { + LogFatal() << "Protocol error (handshake; wrong ephemeral public key length)" << std::endl; + return false; + } + return true; +} + CryptoPubServer::CryptoPubServer(const std::string& secret_key) : secret_key_(secret_key), @@ -228,7 +257,7 @@ void CryptoPubServerConnection::OnReadable() { } void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { - if (decoded.GetType() != TLV_TYPE_CLIENT_HANDSHAKE) { + if (decoded.GetType() != TLV_TYPE_HANDSHAKE) { LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl; return; } @@ -249,20 +278,7 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { return; } - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted)); - if (!decrypted.get()) { - LogFatal() << "Protocol error (client handshake -- decryption failure)" << std::endl; - return; - } - - auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY); - if (!peer_ephemeral_public_key) { - LogFatal() << "Protocol error (client handshake -- no ephemeral public key)" << std::endl; - return; - } - peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue(); - if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) { - LogFatal() << "Protocol error (client handshake -- wrong ephemeral public key length)" << std::endl; + if (!HandleSecureHandshake(*encrypted)) { return; } @@ -273,16 +289,9 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { } void CryptoPubServerConnection::SendHandshake() { - std::string ephemeral_public_key; - CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); - - TLVNode handshake(TLV_TYPE_SERVER_HANDSHAKE); - TLVNode secure_handshake(TLV_TYPE_SERVER_HANDSHAKE_SECURE); - secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key)); - CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake); - + auto handshake = BuildSecureHandshake(); std::string out; - handshake.Encode(&out); + handshake->Encode(&out); bufferevent_write(bev_, out.data(), out.length()); } @@ -365,31 +374,7 @@ void CryptoPubClient::OnReadable() { } void CryptoPubClient::OnHandshake(const TLVNode& decoded) { - if (decoded.GetType() != TLV_TYPE_SERVER_HANDSHAKE) { - LogFatal() << "Protocol error (server handshake -- wrong message type)" << std::endl; - return; - } - - auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); - if (!encrypted) { - LogFatal() << "Protocol error (server handshake -- no encrypted portion)" << std::endl; - return; - } - - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted)); - if (!decrypted.get()) { - LogFatal() << "Protocol error (server handshake -- decryption failure)" << std::endl; - return; - } - - auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY); - if (!peer_ephemeral_public_key) { - LogFatal() << "Protocol error (server handshake -- no ephemeral public key)" << std::endl; - return; - } - peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue(); - if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) { - LogFatal() << "Protocol error (server handshake -- wrong ephemeral public key length)" << std::endl; + if (!HandleSecureHandshake(decoded)) { return; } @@ -414,16 +399,13 @@ void CryptoPubClient::OnConnect() { } void CryptoPubClient::SendHandshake() { - std::string ephemeral_public_key; - CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); + auto secure_handshake = BuildSecureHandshake(); - TLVNode handshake(TLV_TYPE_CLIENT_HANDSHAKE); + TLVNode handshake(TLV_TYPE_HANDSHAKE); std::string public_key; CryptoUtil::DerivePublicKey(secret_key_, &public_key); handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key)); - TLVNode secure_handshake(TLV_TYPE_CLIENT_HANDSHAKE_SECURE); - secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key)); - CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake); + handshake.AppendChild(secure_handshake.release()); std::string out; handshake.Encode(&out); diff --git a/crypto.h b/crypto.h index c9b183d..9d972a7 100644 --- a/crypto.h +++ b/crypto.h @@ -16,8 +16,8 @@ class CryptoUtil { static void ReadKeyFromFile(const std::string& filename, std::string* key); static void WriteKeyToFile(const std::string& filename, const std::string& key); - static void EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container); - static TLVNode *DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input); + static std::unique_ptr EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input); + static std::unique_ptr DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input); }; class CryptoBase { @@ -30,6 +30,9 @@ class CryptoConnBase : public CryptoBase { protected: CryptoConnBase(const std::string& secret_key); + std::unique_ptr BuildSecureHandshake(); + bool HandleSecureHandshake(const TLVNode& node); + enum { AWAITING_HANDSHAKE, READY, diff --git a/tlv.cc b/tlv.cc index 6addf06..d269808 100644 --- a/tlv.cc +++ b/tlv.cc @@ -43,7 +43,7 @@ void TLVNode::Encode(std::string *output) const { } } -TLVNode* TLVNode::Decode(const std::string& input) { +std::unique_ptr TLVNode::Decode(const std::string& input) { if (input.length() < sizeof(struct header)) { return nullptr; } @@ -71,10 +71,10 @@ TLVNode* TLVNode::Decode(const std::string& input) { cursor += sub_length; } - return container.release(); + return container; } else { // Scalar - return new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length))); + return std::unique_ptr(new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length)))); } } diff --git a/tlv.h b/tlv.h index d22a449..df421c1 100644 --- a/tlv.h +++ b/tlv.h @@ -9,7 +9,7 @@ class TLVNode { TLVNode(const uint16_t type, const std::string value); ~TLVNode(); - static TLVNode* Decode(const std::string& input); + static std::unique_ptr Decode(const std::string& input); void AppendChild(TLVNode* child);