Switch to unique_ptr for ownership clarity in places, combine more code.

This commit is contained in:
Ian Gulliver
2015-02-07 15:07:34 -08:00
parent 844db000f6
commit d82cb789e3
4 changed files with 56 additions and 71 deletions

112
crypto.cc
View File

@@ -21,12 +21,10 @@
#define TLV_TYPE_DOWNSTREAM_BITRATE 0x0003
#define TLV_TYPE_ENCRYPTED 0x8000
#define TLV_TYPE_CLIENT_HANDSHAKE 0x8001
#define TLV_TYPE_CLIENT_HANDSHAKE_SECURE 0x8002
#define TLV_TYPE_SERVER_HANDSHAKE 0x8003
#define TLV_TYPE_SERVER_HANDSHAKE_SECURE 0x8004
#define TLV_TYPE_TUNNEL_REQUEST 0x8005
#define TLV_TYPE_CHANNEL 0x8006
#define TLV_TYPE_HANDSHAKE 0x8001
#define TLV_TYPE_HANDSHAKE_SECURE 0x8002
#define TLV_TYPE_TUNNEL_REQUEST 0x8003
#define TLV_TYPE_CHANNEL 0x8004
std::string CryptoUtil::BinToHex(const std::string& bin) {
@@ -73,7 +71,7 @@ void CryptoUtil::WriteKeyToFile(const std::string& filename, const std::string&
key_file << key;
}
void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container) {
std::unique_ptr<TLVNode> CryptoUtil::EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
@@ -88,14 +86,14 @@ void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::s
unsigned char output[encrypted_bytes];
assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data()));
auto encrypted = new TLVNode(TLV_TYPE_ENCRYPTED);
std::unique_ptr<TLVNode> encrypted(new TLVNode(TLV_TYPE_ENCRYPTED));
encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES)));
encrypted->AppendChild(new TLVNode(TLV_TYPE_ENCRYPTED_BLOB, std::string((char*)output, encrypted_bytes)));
container->AppendChild(encrypted);
return encrypted;
}
TLVNode* CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
std::unique_ptr<TLVNode> CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
assert(input.GetType() == TLV_TYPE_ENCRYPTED);
@@ -136,6 +134,37 @@ CryptoConnBase::CryptoConnBase(const std::string& secret_key)
: secret_key_(secret_key),
state_(AWAITING_HANDSHAKE) {}
std::unique_ptr<TLVNode> CryptoConnBase::BuildSecureHandshake() {
std::string ephemeral_public_key;
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE);
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake);
}
bool CryptoConnBase::HandleSecureHandshake(const TLVNode& node) {
assert(node.GetType() == TLV_TYPE_ENCRYPTED);
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node));
if (!decrypted.get()) {
LogFatal() << "Protocol error (handshake; decryption failure)" << std::endl;
return false;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (handshake; no ephemeral public key)" << std::endl;
return false;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (handshake; wrong ephemeral public key length)" << std::endl;
return false;
}
return true;
}
CryptoPubServer::CryptoPubServer(const std::string& secret_key)
: secret_key_(secret_key),
@@ -228,7 +257,7 @@ void CryptoPubServerConnection::OnReadable() {
}
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_CLIENT_HANDSHAKE) {
if (decoded.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl;
return;
}
@@ -249,20 +278,7 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
return;
}
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted));
if (!decrypted.get()) {
LogFatal() << "Protocol error (client handshake -- decryption failure)" << std::endl;
return;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (client handshake -- no ephemeral public key)" << std::endl;
return;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (client handshake -- wrong ephemeral public key length)" << std::endl;
if (!HandleSecureHandshake(*encrypted)) {
return;
}
@@ -273,16 +289,9 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
}
void CryptoPubServerConnection::SendHandshake() {
std::string ephemeral_public_key;
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
TLVNode handshake(TLV_TYPE_SERVER_HANDSHAKE);
TLVNode secure_handshake(TLV_TYPE_SERVER_HANDSHAKE_SECURE);
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
auto handshake = BuildSecureHandshake();
std::string out;
handshake.Encode(&out);
handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
@@ -365,31 +374,7 @@ void CryptoPubClient::OnReadable() {
}
void CryptoPubClient::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_SERVER_HANDSHAKE) {
LogFatal() << "Protocol error (server handshake -- wrong message type)" << std::endl;
return;
}
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal() << "Protocol error (server handshake -- no encrypted portion)" << std::endl;
return;
}
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted));
if (!decrypted.get()) {
LogFatal() << "Protocol error (server handshake -- decryption failure)" << std::endl;
return;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (server handshake -- no ephemeral public key)" << std::endl;
return;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (server handshake -- wrong ephemeral public key length)" << std::endl;
if (!HandleSecureHandshake(decoded)) {
return;
}
@@ -414,16 +399,13 @@ void CryptoPubClient::OnConnect() {
}
void CryptoPubClient::SendHandshake() {
std::string ephemeral_public_key;
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
auto secure_handshake = BuildSecureHandshake();
TLVNode handshake(TLV_TYPE_CLIENT_HANDSHAKE);
TLVNode handshake(TLV_TYPE_HANDSHAKE);
std::string public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key));
TLVNode secure_handshake(TLV_TYPE_CLIENT_HANDSHAKE_SECURE);
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
handshake.AppendChild(secure_handshake.release());
std::string out;
handshake.Encode(&out);