diff --git a/crypto.cc b/crypto.cc index e44bc3e..7bd5dbe 100644 --- a/crypto.cc +++ b/crypto.cc @@ -194,21 +194,35 @@ void CryptoPubServerConnection::OnReadable() { char buf[UINT16_MAX]; int bytes = bufferevent_read(bev_, buf, UINT16_MAX); const std::string input(buf, bytes); - - if (state_ == AWAITING_HANDSHAKE) { - OnHandshake(input); - return; - } -} - -void CryptoPubServerConnection::OnHandshake(const std::string& input) { std::unique_ptr decoded(TLVNode::Decode(input)); + if (!decoded.get()) { // TODO: re-buffer? return; } - auto client_public_key = decoded->FindChild(TLV_TYPE_PUBLIC_KEY); + if (state_ == AWAITING_HANDSHAKE) { + OnHandshake(*decoded); + return; + } + + if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { + LogFatal() << "Protocol error (unexpected message type)" << std::endl; + return; + } + + std::unique_ptr decrypted(DecryptDecode(ephemeral_secret_key_, client_ephemeral_public_key_, *decoded)); + if (!decrypted.get()) { + LogFatal() << "Protocol error (decryption failure)" << std::endl; + return; + } + + switch (decrypted->GetType()) { + } +} + +void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { + auto client_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY); if (!client_public_key) { LogFatal() << "Protocol error (client handshake -- no public key)" << std::endl; return; @@ -218,7 +232,7 @@ void CryptoPubServerConnection::OnHandshake(const std::string& input) { LogFatal() << "Protocol error (client handshake -- wrong public key length)" << std::endl; return; } - auto encrypted = decoded->FindChild(TLV_TYPE_ENCRYPTED); + auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); if (!encrypted) { LogFatal() << "Protocol error (client handshake -- no encrypted portion)" << std::endl; return; @@ -309,21 +323,35 @@ void CryptoPubClient::OnReadable() { char buf[UINT16_MAX]; int bytes = bufferevent_read(bev_, buf, UINT16_MAX); const std::string input(buf, bytes); - - if (state_ == AWAITING_HANDSHAKE) { - OnHandshake(input); - return; - } -} - -void CryptoPubClient::OnHandshake(const std::string& input) { std::unique_ptr decoded(TLVNode::Decode(input)); + if (!decoded.get()) { // TODO: re-buffer? return; } - auto encrypted = decoded->FindChild(TLV_TYPE_ENCRYPTED); + if (state_ == AWAITING_HANDSHAKE) { + OnHandshake(*decoded); + return; + } + + if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { + LogFatal() << "Protocol error (unexpected message type)" << std::endl; + return; + } + + std::unique_ptr decrypted(DecryptDecode(ephemeral_secret_key_, server_ephemeral_public_key_, *decoded)); + if (!decrypted.get()) { + LogFatal() << "Protocol error (decryption failure)" << std::endl; + return; + } + + switch (decrypted->GetType()) { + } +} + +void CryptoPubClient::OnHandshake(const TLVNode& decoded) { + auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); if (!encrypted) { LogFatal() << "Protocol error (server handshake -- no encrypted portion)" << std::endl; return; diff --git a/crypto.h b/crypto.h index e9a9e56..2a0501c 100644 --- a/crypto.h +++ b/crypto.h @@ -51,7 +51,7 @@ class CryptoPubServerConnection : public CryptoBase { private: static void OnReadable_(struct bufferevent* bev, void* this__); void OnReadable(); - void OnHandshake(const std::string& input); + void OnHandshake(const TLVNode& decoded); static void OnError_(struct bufferevent* bev, const short what, void* this__); void OnError(const short what); @@ -82,7 +82,7 @@ class CryptoPubClient : public CryptoBase { private: static void OnReadable_(struct bufferevent* bev, void* this__); void OnReadable(); - void OnHandshake(const std::string& input); + void OnHandshake(const TLVNode& decoded); static void OnConnectOrError_(struct bufferevent* bev, const short what, void* this__); void OnConnect(); void OnError();