Switch to unique_ptr for ownership clarity in places, combine more code.
This commit is contained in:
112
crypto.cc
112
crypto.cc
@@ -21,12 +21,10 @@
|
||||
#define TLV_TYPE_DOWNSTREAM_BITRATE 0x0003
|
||||
|
||||
#define TLV_TYPE_ENCRYPTED 0x8000
|
||||
#define TLV_TYPE_CLIENT_HANDSHAKE 0x8001
|
||||
#define TLV_TYPE_CLIENT_HANDSHAKE_SECURE 0x8002
|
||||
#define TLV_TYPE_SERVER_HANDSHAKE 0x8003
|
||||
#define TLV_TYPE_SERVER_HANDSHAKE_SECURE 0x8004
|
||||
#define TLV_TYPE_TUNNEL_REQUEST 0x8005
|
||||
#define TLV_TYPE_CHANNEL 0x8006
|
||||
#define TLV_TYPE_HANDSHAKE 0x8001
|
||||
#define TLV_TYPE_HANDSHAKE_SECURE 0x8002
|
||||
#define TLV_TYPE_TUNNEL_REQUEST 0x8003
|
||||
#define TLV_TYPE_CHANNEL 0x8004
|
||||
|
||||
|
||||
std::string CryptoUtil::BinToHex(const std::string& bin) {
|
||||
@@ -73,7 +71,7 @@ void CryptoUtil::WriteKeyToFile(const std::string& filename, const std::string&
|
||||
key_file << key;
|
||||
}
|
||||
|
||||
void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container) {
|
||||
std::unique_ptr<TLVNode> CryptoUtil::EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
|
||||
assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
|
||||
assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
|
||||
|
||||
@@ -88,14 +86,14 @@ void CryptoUtil::EncodeEncryptAppend(const std::string& secret_key, const std::s
|
||||
unsigned char output[encrypted_bytes];
|
||||
assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data()));
|
||||
|
||||
auto encrypted = new TLVNode(TLV_TYPE_ENCRYPTED);
|
||||
std::unique_ptr<TLVNode> encrypted(new TLVNode(TLV_TYPE_ENCRYPTED));
|
||||
encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES)));
|
||||
encrypted->AppendChild(new TLVNode(TLV_TYPE_ENCRYPTED_BLOB, std::string((char*)output, encrypted_bytes)));
|
||||
|
||||
container->AppendChild(encrypted);
|
||||
return encrypted;
|
||||
}
|
||||
|
||||
TLVNode* CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
|
||||
std::unique_ptr<TLVNode> CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) {
|
||||
assert(secret_key.length() == crypto_box_SECRETKEYBYTES);
|
||||
assert(public_key.length() == crypto_box_PUBLICKEYBYTES);
|
||||
assert(input.GetType() == TLV_TYPE_ENCRYPTED);
|
||||
@@ -136,6 +134,37 @@ CryptoConnBase::CryptoConnBase(const std::string& secret_key)
|
||||
: secret_key_(secret_key),
|
||||
state_(AWAITING_HANDSHAKE) {}
|
||||
|
||||
std::unique_ptr<TLVNode> CryptoConnBase::BuildSecureHandshake() {
|
||||
std::string ephemeral_public_key;
|
||||
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
|
||||
|
||||
TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE);
|
||||
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
|
||||
return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake);
|
||||
}
|
||||
|
||||
bool CryptoConnBase::HandleSecureHandshake(const TLVNode& node) {
|
||||
assert(node.GetType() == TLV_TYPE_ENCRYPTED);
|
||||
|
||||
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node));
|
||||
if (!decrypted.get()) {
|
||||
LogFatal() << "Protocol error (handshake; decryption failure)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
|
||||
if (!peer_ephemeral_public_key) {
|
||||
LogFatal() << "Protocol error (handshake; no ephemeral public key)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
|
||||
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
|
||||
LogFatal() << "Protocol error (handshake; wrong ephemeral public key length)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
CryptoPubServer::CryptoPubServer(const std::string& secret_key)
|
||||
: secret_key_(secret_key),
|
||||
@@ -228,7 +257,7 @@ void CryptoPubServerConnection::OnReadable() {
|
||||
}
|
||||
|
||||
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
|
||||
if (decoded.GetType() != TLV_TYPE_CLIENT_HANDSHAKE) {
|
||||
if (decoded.GetType() != TLV_TYPE_HANDSHAKE) {
|
||||
LogFatal() << "Protocol error (client handshake -- wrong message type)" << std::endl;
|
||||
return;
|
||||
}
|
||||
@@ -249,20 +278,7 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted));
|
||||
if (!decrypted.get()) {
|
||||
LogFatal() << "Protocol error (client handshake -- decryption failure)" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
|
||||
if (!peer_ephemeral_public_key) {
|
||||
LogFatal() << "Protocol error (client handshake -- no ephemeral public key)" << std::endl;
|
||||
return;
|
||||
}
|
||||
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
|
||||
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
|
||||
LogFatal() << "Protocol error (client handshake -- wrong ephemeral public key length)" << std::endl;
|
||||
if (!HandleSecureHandshake(*encrypted)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -273,16 +289,9 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
|
||||
}
|
||||
|
||||
void CryptoPubServerConnection::SendHandshake() {
|
||||
std::string ephemeral_public_key;
|
||||
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
|
||||
|
||||
TLVNode handshake(TLV_TYPE_SERVER_HANDSHAKE);
|
||||
TLVNode secure_handshake(TLV_TYPE_SERVER_HANDSHAKE_SECURE);
|
||||
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
|
||||
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
|
||||
|
||||
auto handshake = BuildSecureHandshake();
|
||||
std::string out;
|
||||
handshake.Encode(&out);
|
||||
handshake->Encode(&out);
|
||||
bufferevent_write(bev_, out.data(), out.length());
|
||||
}
|
||||
|
||||
@@ -365,31 +374,7 @@ void CryptoPubClient::OnReadable() {
|
||||
}
|
||||
|
||||
void CryptoPubClient::OnHandshake(const TLVNode& decoded) {
|
||||
if (decoded.GetType() != TLV_TYPE_SERVER_HANDSHAKE) {
|
||||
LogFatal() << "Protocol error (server handshake -- wrong message type)" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
|
||||
if (!encrypted) {
|
||||
LogFatal() << "Protocol error (server handshake -- no encrypted portion)" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_ptr<TLVNode> decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, *encrypted));
|
||||
if (!decrypted.get()) {
|
||||
LogFatal() << "Protocol error (server handshake -- decryption failure)" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto peer_ephemeral_public_key = decrypted->FindChild(TLV_TYPE_PUBLIC_KEY);
|
||||
if (!peer_ephemeral_public_key) {
|
||||
LogFatal() << "Protocol error (server handshake -- no ephemeral public key)" << std::endl;
|
||||
return;
|
||||
}
|
||||
peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue();
|
||||
if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) {
|
||||
LogFatal() << "Protocol error (server handshake -- wrong ephemeral public key length)" << std::endl;
|
||||
if (!HandleSecureHandshake(decoded)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -414,16 +399,13 @@ void CryptoPubClient::OnConnect() {
|
||||
}
|
||||
|
||||
void CryptoPubClient::SendHandshake() {
|
||||
std::string ephemeral_public_key;
|
||||
CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key);
|
||||
auto secure_handshake = BuildSecureHandshake();
|
||||
|
||||
TLVNode handshake(TLV_TYPE_CLIENT_HANDSHAKE);
|
||||
TLVNode handshake(TLV_TYPE_HANDSHAKE);
|
||||
std::string public_key;
|
||||
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
|
||||
handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key));
|
||||
TLVNode secure_handshake(TLV_TYPE_CLIENT_HANDSHAKE_SECURE);
|
||||
secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key));
|
||||
CryptoUtil::EncodeEncryptAppend(secret_key_, peer_public_key_, secure_handshake, &handshake);
|
||||
handshake.AppendChild(secure_handshake.release());
|
||||
|
||||
std::string out;
|
||||
handshake.Encode(&out);
|
||||
|
||||
7
crypto.h
7
crypto.h
@@ -16,8 +16,8 @@ class CryptoUtil {
|
||||
static void ReadKeyFromFile(const std::string& filename, std::string* key);
|
||||
static void WriteKeyToFile(const std::string& filename, const std::string& key);
|
||||
|
||||
static void EncodeEncryptAppend(const std::string& secret_key, const std::string& public_key, const TLVNode& input, TLVNode* container);
|
||||
static TLVNode *DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input);
|
||||
static std::unique_ptr<TLVNode> EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input);
|
||||
static std::unique_ptr<TLVNode> DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input);
|
||||
};
|
||||
|
||||
class CryptoBase {
|
||||
@@ -30,6 +30,9 @@ class CryptoConnBase : public CryptoBase {
|
||||
protected:
|
||||
CryptoConnBase(const std::string& secret_key);
|
||||
|
||||
std::unique_ptr<TLVNode> BuildSecureHandshake();
|
||||
bool HandleSecureHandshake(const TLVNode& node);
|
||||
|
||||
enum {
|
||||
AWAITING_HANDSHAKE,
|
||||
READY,
|
||||
|
||||
6
tlv.cc
6
tlv.cc
@@ -43,7 +43,7 @@ void TLVNode::Encode(std::string *output) const {
|
||||
}
|
||||
}
|
||||
|
||||
TLVNode* TLVNode::Decode(const std::string& input) {
|
||||
std::unique_ptr<TLVNode> TLVNode::Decode(const std::string& input) {
|
||||
if (input.length() < sizeof(struct header)) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -71,10 +71,10 @@ TLVNode* TLVNode::Decode(const std::string& input) {
|
||||
cursor += sub_length;
|
||||
}
|
||||
|
||||
return container.release();
|
||||
return container;
|
||||
} else {
|
||||
// Scalar
|
||||
return new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length)));
|
||||
return std::unique_ptr<TLVNode>(new TLVNode(htons(header->type), input.substr(sizeof(*header), htons(header->value_length))));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user