Error handling fixes.

This commit is contained in:
Ian Gulliver
2015-02-07 15:56:29 -08:00
parent c93121ddf2
commit 578615a99e
2 changed files with 24 additions and 20 deletions

View File

@@ -123,17 +123,21 @@ std::ostream& CryptoBase::Log(void *obj) {
return std::cerr << buf; return std::cerr << buf;
} }
std::ostream& CryptoBase::LogFatal(void *obj) {
std::ostream& ret = Log(obj);
delete this;
return ret;
}
CryptoPubConnBase::CryptoPubConnBase(const std::string& secret_key) CryptoPubConnBase::CryptoPubConnBase(const std::string& secret_key)
: secret_key_(secret_key), : secret_key_(secret_key),
state_(AWAITING_HANDSHAKE) {} state_(AWAITING_HANDSHAKE) {}
CryptoPubConnBase::~CryptoPubConnBase() {
bufferevent_free(bev_);
}
void CryptoPubConnBase::LogFatal(const std::string& msg, void *obj) {
Log(obj) << msg << std::endl;
delete this;
return;
}
std::unique_ptr<TLVNode> CryptoPubConnBase::BuildSecureHandshake() { std::unique_ptr<TLVNode> CryptoPubConnBase::BuildSecureHandshake() {
std::string ephemeral_public_key; std::string ephemeral_public_key;
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
@@ -148,18 +152,18 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) {
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node)); std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node));
if (!decrypted.get()) { if (!decrypted.get()) {
LogFatal() << "Protocol error (handshake; decryption failure)" << std::endl; LogFatal("Protocol error (handshake; decryption failure)");
return false; return false;
} }
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY); auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) { if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (handshake; no ephemeral public key)" << std::endl; LogFatal("Protocol error (handshake; no ephemeral public key)");
return false; return false;
} }
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue(); peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) { if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (handshake; wrong ephemeral public key length)" << std::endl; LogFatal("Protocol error (handshake; wrong ephemeral public key length)");
return false; return false;
} }
return true; return true;
@@ -194,18 +198,18 @@ void CryptoPubConnBase::OnReadable() {
} }
if (decoded->GetType() != TLV_TYPE_ENCRYPTED) { if (decoded->GetType() != TLV_TYPE_ENCRYPTED) {
LogFatal() << "Protocol error (wrong message type)" << std::endl; LogFatal("Protocol error (wrong message type)");
return; return;
} }
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded)); std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded));
if (!decrypted.get()) { if (!decrypted.get()) {
LogFatal() << "Protocol error (decryption failure)" << std::endl; LogFatal("Protocol error (decryption failure)");
return; return;
} }
if (!OnMessage(*decrypted)) { if (!OnMessage(*decrypted)) {
LogFatal() << "Protocol error (message handling)" << std::endl; LogFatal("Protocol error (message handling)");
return; return;
} }
} }
@@ -262,28 +266,27 @@ CryptoPubServerConnection::CryptoPubServerConnection(struct bufferevent* bev, co
CryptoPubServerConnection::~CryptoPubServerConnection() { CryptoPubServerConnection::~CryptoPubServerConnection() {
Log() << "Connection closed" << std::endl; Log() << "Connection closed" << std::endl;
bufferevent_free(bev_);
} }
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_HANDSHAKE) { if (decoded.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl; LogFatal("Protocol error (client handshake -- wrong message type)");
return; return;
} }
auto peer_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY); auto peer_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_public_key) { if (!peer_public_key) {
LogFatal() << "Protocol error (client handshake -- no public key)" << std::endl; LogFatal("Protocol error (client handshake -- no public key)");
return; return;
} }
peer_public_key_ = peer_public_key->GetValue(); peer_public_key_ = peer_public_key->GetValue();
if (peer_public_key_.length() != crypto_box_PUBLICKEYBYTES) { if (peer_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (client handshake -- wrong public key length)" << std::endl; LogFatal("Protocol error (client handshake -- wrong public key length)");
return; return;
} }
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) { if (!encrypted) {
LogFatal() << "Protocol error (client handshake -- no encrypted portion)" << std::endl; LogFatal("Protocol error (client handshake -- no encrypted portion)");
return; return;
} }
@@ -337,7 +340,6 @@ CryptoPubClient::CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const
} }
CryptoPubClient::~CryptoPubClient() { CryptoPubClient::~CryptoPubClient() {
bufferevent_free(bev_);
event_base_free(event_base_); event_base_free(event_base_);
} }

View File

@@ -21,14 +21,16 @@ class CryptoUtil {
}; };
class CryptoBase { class CryptoBase {
public: protected:
std::ostream& Log(void *obj=nullptr); std::ostream& Log(void *obj=nullptr);
std::ostream& LogFatal(void *obj=nullptr);
}; };
class CryptoPubConnBase : public CryptoBase { class CryptoPubConnBase : public CryptoBase {
protected: protected:
CryptoPubConnBase(const std::string& secret_key); CryptoPubConnBase(const std::string& secret_key);
virtual ~CryptoPubConnBase();
void LogFatal(const std::string& msg, void *obj=nullptr);
std::unique_ptr<TLVNode> BuildSecureHandshake(); std::unique_ptr<TLVNode> BuildSecureHandshake();
bool HandleSecureHandshake(const TLVNode& node); bool HandleSecureHandshake(const TLVNode& node);