From 4274c7ca1df3584d4aad390167a5b9cf97dfeba2 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 7 Feb 2015 15:25:22 -0800 Subject: [PATCH] More code dedupe. --- crypto.cc | 130 +++++++++++++++++++++++------------------------------- crypto.h | 18 +++++--- 2 files changed, 66 insertions(+), 82 deletions(-) diff --git a/crypto.cc b/crypto.cc index b67a754..89fa86c 100644 --- a/crypto.cc +++ b/crypto.cc @@ -165,6 +165,44 @@ bool CryptoConnBase::HandleSecureHandshake(const TLVNode& node) { return true; } +void CryptoConnBase::OnReadable_(struct bufferevent* bev, void* this__) { + auto this_ = (CryptoConnBase*)this__; + this_->OnReadable(); +} + +void CryptoConnBase::OnReadable() { + char buf[UINT16_MAX]; + int bytes = bufferevent_read(bev_, buf, UINT16_MAX); + const std::string input(buf, bytes); + std::unique_ptr decoded(TLVNode::Decode(input)); + + if (!decoded.get()) { + // TODO: re-buffer? + return; + } + + if (state_ == AWAITING_HANDSHAKE) { + OnHandshake(*decoded); + return; + } + + if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { + LogFatal() << "Protocol error (wrong message type)" << std::endl; + return; + } + + std::unique_ptr decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded)); + if (!decrypted.get()) { + LogFatal() << "Protocol error (decryption failure)" << std::endl; + return; + } + + if (!OnMessage(*decrypted)) { + LogFatal() << "Protocol error (message handling)" << std::endl; + return; + } +} + CryptoPubServer::CryptoPubServer(const std::string& secret_key) : secret_key_(secret_key), @@ -211,8 +249,8 @@ void CryptoPubServer::Loop() { CryptoPubServerConnection::CryptoPubServerConnection(struct bufferevent* bev, const std::string& secret_key) - : CryptoConnBase(secret_key), - bev_(bev) { + : CryptoConnBase(secret_key) { + bev_ = bev; } CryptoPubServerConnection::~CryptoPubServerConnection() { @@ -220,42 +258,6 @@ CryptoPubServerConnection::~CryptoPubServerConnection() { bufferevent_free(bev_); } -void CryptoPubServerConnection::OnReadable_(struct bufferevent* bev, void* this__) { - auto this_ = (CryptoPubServerConnection*)this__; - this_->OnReadable(); -} - -void CryptoPubServerConnection::OnReadable() { - char buf[UINT16_MAX]; - int bytes = bufferevent_read(bev_, buf, UINT16_MAX); - const std::string input(buf, bytes); - std::unique_ptr decoded(TLVNode::Decode(input)); - - if (!decoded.get()) { - // TODO: re-buffer? - return; - } - - if (state_ == AWAITING_HANDSHAKE) { - OnHandshake(*decoded); - return; - } - - if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { - LogFatal() << "Protocol error (wrong message type)" << std::endl; - return; - } - - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded)); - if (!decrypted.get()) { - LogFatal() << "Protocol error (decryption failure)" << std::endl; - return; - } - - switch (decrypted->GetType()) { - } -} - void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { if (decoded.GetType() != TLV_TYPE_HANDSHAKE) { LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl; @@ -288,6 +290,13 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { Log() << "Handshake successful (client ID: " << CryptoUtil::BinToHex(peer_public_key_) << ")" << std::endl; } +bool CryptoPubServerConnection::OnMessage(const TLVNode& message) { + switch (message.GetType()) { + default: + return false; + } +} + void CryptoPubServerConnection::SendHandshake() { auto handshake = BuildSecureHandshake(); std::string out; @@ -308,8 +317,8 @@ void CryptoPubServerConnection::OnError(const short what) { CryptoPubClient::CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const std::string& secret_key, const std::string& server_public_key, const std::list& channel_bitrates) : CryptoConnBase(secret_key), event_base_(event_base_new()), - bev_(bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE)), channel_bitrates_(channel_bitrates) { + bev_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); peer_public_key_ = server_public_key; assert(secret_key_.length() == crypto_box_SECRETKEYBYTES); assert(peer_public_key_.length() == crypto_box_PUBLICKEYBYTES); @@ -337,42 +346,6 @@ CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address return ret; } -void CryptoPubClient::OnReadable_(struct bufferevent* bev, void* this__) { - auto this_ = (CryptoPubClient*)this__; - this_->OnReadable(); -} - -void CryptoPubClient::OnReadable() { - char buf[UINT16_MAX]; - int bytes = bufferevent_read(bev_, buf, UINT16_MAX); - const std::string input(buf, bytes); - std::unique_ptr decoded(TLVNode::Decode(input)); - - if (!decoded.get()) { - // TODO: re-buffer? - return; - } - - if (state_ == AWAITING_HANDSHAKE) { - OnHandshake(*decoded); - return; - } - - if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { - LogFatal() << "Protocol error (unexpected message type)" << std::endl; - return; - } - - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded)); - if (!decrypted.get()) { - LogFatal() << "Protocol error (decryption failure)" << std::endl; - return; - } - - switch (decrypted->GetType()) { - } -} - void CryptoPubClient::OnHandshake(const TLVNode& decoded) { if (!HandleSecureHandshake(decoded)) { return; @@ -384,6 +357,13 @@ void CryptoPubClient::OnHandshake(const TLVNode& decoded) { SendTunnelRequest(); } +bool CryptoPubClient::OnMessage(const TLVNode& message) { + switch (message.GetType()) { + default: + return false; + } +} + void CryptoPubClient::OnConnectOrError_(struct bufferevent* bev, const short what, void* this__) { auto this_ = (CryptoPubClient*)this__; if (what == BEV_EVENT_CONNECTED) { diff --git a/crypto.h b/crypto.h index 9d972a7..64c554e 100644 --- a/crypto.h +++ b/crypto.h @@ -33,11 +33,18 @@ class CryptoConnBase : public CryptoBase { std::unique_ptr BuildSecureHandshake(); bool HandleSecureHandshake(const TLVNode& node); + static void OnReadable_(struct bufferevent* bev, void* this__); + void OnReadable(); + virtual void OnHandshake(const TLVNode& decoded) = 0; + virtual bool OnMessage(const TLVNode& node) = 0; + enum { AWAITING_HANDSHAKE, READY, } state_; + struct bufferevent* bev_; + const std::string secret_key_; std::string peer_public_key_; std::string ephemeral_secret_key_; @@ -68,16 +75,14 @@ class CryptoPubServerConnection : public CryptoConnBase { ~CryptoPubServerConnection(); private: - static void OnReadable_(struct bufferevent* bev, void* this__); - void OnReadable(); void OnHandshake(const TLVNode& decoded); + bool OnMessage(const TLVNode& node); + static void OnError_(struct bufferevent* bev, const short what, void* this__); void OnError(const short what); void SendHandshake(); - struct bufferevent* bev_; - friend CryptoPubServer; }; @@ -91,9 +96,9 @@ class CryptoPubClient : public CryptoConnBase { void Loop(); private: - static void OnReadable_(struct bufferevent* bev, void* this__); - void OnReadable(); void OnHandshake(const TLVNode& decoded); + bool OnMessage(const TLVNode& node); + static void OnConnectOrError_(struct bufferevent* bev, const short what, void* this__); void OnConnect(); void OnError(); @@ -102,7 +107,6 @@ class CryptoPubClient : public CryptoConnBase { void SendTunnelRequest(); struct event_base* event_base_; - struct bufferevent* bev_; const std::list channel_bitrates_; };