Switch to unique_ptr for ownership clarity in places, combine more code.

This commit is contained in:
Ian Gulliver
2015-02-07 15:07:34 -08:00
parent 844db000f6
commit d82cb789e3
4 changed files with 56 additions and 71 deletions

112
crypto.cc
View File

@@ -21,12 +21,10 @@
#define TLV_TYPE_DOWNSTREAM_BITRATE 0x0003 #define TLV_TYPE_DOWNSTREAM_BITRATE 0x0003
#define TLV_TYPE_ENCRYPTED 0x8000 #define TLV_TYPE_ENCRYPTED 0x8000
#define TLV_TYPE_CLIENT_HANDSHAKE 0x8001 #define TLV_TYPE_HANDSHAKE 0x8001
#define TLV_TYPE_CLIENT_HANDSHAKE_SECURE 0x8002 #define TLV_TYPE_HANDSHAKE_SECURE 0x8002
#define TLV_TYPE_SERVER_HANDSHAKE 0x8003 #define TLV_TYPE_TUNNEL_REQUEST 0x8003
#define TLV_TYPE_SERVER_HANDSHAKE_SECURE 0x8004 #define TLV_TYPE_CHANNEL 0x8004
#define TLV_TYPE_TUNNEL_REQUEST 0x8005
#define TLV_TYPE_CHANNEL 0x8006
std::string CryptoUtil::BinToHex(const std::string& bin) { std::string CryptoUtil::BinToHex(const std::string& bin) {
@@ -73,7 +71,7 @@ void CryptoUtil::WriteKeyToFile(const std::string& filename, const std::string&
key_file << key; key_file << key;
} }
void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container) { std::unique_ptr<TLVNode> CryptoUtil::EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
assert(secret_key.length() == crypto_box_SECRETKEYBYTES); assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
assert(public_key.length() == crypto_box_PUBLICKEYBYTES); assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
@@ -88,14 +86,14 @@ void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::s
unsigned char output[encrypted_bytes]; unsigned char output[encrypted_bytes];
assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data())); assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data()));
auto encrypted = new TLVNode(TLV_TYPE_ENCRYPTED); std::unique_ptr<TLVNode> encrypted(new TLVNode(TLV_TYPE_ENCRYPTED));
encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES))); encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES)));
encrypted->AppendChild(new TLVNode(TLV_TYPE_ENCRYPTED_BLOB, std::string((char*)output, encrypted_bytes))); encrypted->AppendChild(new TLVNode(TLV_TYPE_ENCRYPTED_BLOB, std::string((char*)output, encrypted_bytes)));
container->AppendChild(encrypted); return encrypted;
} }
TLVNode* CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { std::unique_ptr<TLVNode> CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
assert(secret_key.length() == crypto_box_SECRETKEYBYTES); assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
assert(public_key.length() == crypto_box_PUBLICKEYBYTES); assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
assert(input.GetType() == TLV_TYPE_ENCRYPTED); assert(input.GetType() == TLV_TYPE_ENCRYPTED);
@@ -136,6 +134,37 @@ CryptoConnBase::CryptoConnBase(const std::string& secret_key)
: secret_key_(secret_key), : secret_key_(secret_key),
state_(AWAITING_HANDSHAKE) {} state_(AWAITING_HANDSHAKE) {}
std::unique_ptr<TLVNode> CryptoConnBase::BuildSecureHandshake() {
std::string ephemeral_public_key;
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE);
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake);
}
bool CryptoConnBase::HandleSecureHandshake(const TLVNode& node) {
assert(node.GetType() == TLV_TYPE_ENCRYPTED);
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node));
if (!decrypted.get()) {
LogFatal() << "Protocol error (handshake; decryption failure)" << std::endl;
return false;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (handshake; no ephemeral public key)" << std::endl;
return false;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (handshake; wrong ephemeral public key length)" << std::endl;
return false;
}
return true;
}
CryptoPubServer::CryptoPubServer(const std::string& secret_key) CryptoPubServer::CryptoPubServer(const std::string& secret_key)
: secret_key_(secret_key), : secret_key_(secret_key),
@@ -228,7 +257,7 @@ void CryptoPubServerConnection::OnReadable() {
} }
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_CLIENT_HANDSHAKE) { if (decoded.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl; LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl;
return; return;
} }
@@ -249,20 +278,7 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
return; return;
} }
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted)); if (!HandleSecureHandshake(*encrypted)) {
if (!decrypted.get()) {
LogFatal() << "Protocol error (client handshake -- decryption failure)" << std::endl;
return;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (client handshake -- no ephemeral public key)" << std::endl;
return;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (client handshake -- wrong ephemeral public key length)" << std::endl;
return; return;
} }
@@ -273,16 +289,9 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
} }
void CryptoPubServerConnection::SendHandshake() { void CryptoPubServerConnection::SendHandshake() {
std::string ephemeral_public_key; auto handshake = BuildSecureHandshake();
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
TLVNode handshake(TLV_TYPE_SERVER_HANDSHAKE);
TLVNode secure_handshake(TLV_TYPE_SERVER_HANDSHAKE_SECURE);
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
std::string out; std::string out;
handshake.Encode(&out); handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length()); bufferevent_write(bev_, out.data(), out.length());
} }
@@ -365,31 +374,7 @@ void CryptoPubClient::OnReadable() {
} }
void CryptoPubClient::OnHandshake(const TLVNode& decoded) { void CryptoPubClient::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_SERVER_HANDSHAKE) { if (!HandleSecureHandshake(decoded)) {
LogFatal() << "Protocol error (server handshake -- wrong message type)" << std::endl;
return;
}
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal() << "Protocol error (server handshake -- no encrypted portion)" << std::endl;
return;
}
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted));
if (!decrypted.get()) {
LogFatal() << "Protocol error (server handshake -- decryption failure)" << std::endl;
return;
}
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_ephemeral_public_key) {
LogFatal() << "Protocol error (server handshake -- no ephemeral public key)" << std::endl;
return;
}
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
LogFatal() << "Protocol error (server handshake -- wrong ephemeral public key length)" << std::endl;
return; return;
} }
@@ -414,16 +399,13 @@ void CryptoPubClient::OnConnect() {
} }
void CryptoPubClient::SendHandshake() { void CryptoPubClient::SendHandshake() {
std::string ephemeral_public_key; auto secure_handshake = BuildSecureHandshake();
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
TLVNode handshake(TLV_TYPE_CLIENT_HANDSHAKE); TLVNode handshake(TLV_TYPE_HANDSHAKE);
std::string public_key; std::string public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key); CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key)); handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key));
TLVNode secure_handshake(TLV_TYPE_CLIENT_HANDSHAKE_SECURE); handshake.AppendChild(secure_handshake.release());
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
std::string out; std::string out;
handshake.Encode(&out); handshake.Encode(&out);

View File

@@ -16,8 +16,8 @@ class CryptoUtil {
static void ReadKeyFromFile(const std::string& filename, std::string* key); static void ReadKeyFromFile(const std::string& filename, std::string* key);
static void WriteKeyToFile(const std::string& filename, const std::string& key); static void WriteKeyToFile(const std::string& filename, const std::string& key);
static void EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container); static std::unique_ptr<TLVNode> EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input);
static TLVNode *DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input); static std::unique_ptr<TLVNode> DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input);
}; };
class CryptoBase { class CryptoBase {
@@ -30,6 +30,9 @@ class CryptoConnBase : public CryptoBase {
protected: protected:
CryptoConnBase(const std::string& secret_key); CryptoConnBase(const std::string& secret_key);
std::unique_ptr<TLVNode> BuildSecureHandshake();
bool HandleSecureHandshake(const TLVNode& node);
enum { enum {
AWAITING_HANDSHAKE, AWAITING_HANDSHAKE,
READY, READY,

6
tlv.cc
View File

@@ -43,7 +43,7 @@ void TLVNode::Encode(std::string *output) const {
} }
} }
TLVNode* TLVNode::Decode(const std::string& input) { std::unique_ptr<TLVNode> TLVNode::Decode(const std::string& input) {
if (input.length() < sizeof(struct header)) { if (input.length() < sizeof(struct header)) {
return nullptr; return nullptr;
} }
@@ -71,10 +71,10 @@ TLVNode* TLVNode::Decode(const std::string& input) {
cursor += sub_length; cursor += sub_length;
} }
return container.release(); return container;
} else { } else {
// Scalar // Scalar
return new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length))); return std::unique_ptr<TLVNode>(new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length))));
} }
} }

2
tlv.h
View File

@@ -9,7 +9,7 @@ class TLVNode {
TLVNode(const uint16_t type, const std::string value); TLVNode(const uint16_t type, const std::string value);
~TLVNode(); ~TLVNode();
static TLVNode* Decode(const std::string& input); static std::unique_ptr<TLVNode> Decode(const std::string& input);
void AppendChild(TLVNode* child); void AppendChild(TLVNode* child);