diff --git a/crypto.cc b/crypto.cc index 25b1425..108ed57 100644 --- a/crypto.cc +++ b/crypto.cc @@ -49,7 +49,12 @@ void CryptoUtil::DerivePublicKey(const SecretKey& secret_key, PublicKey* public_ public_key->MarkSet(); } -std::unique_ptr CryptoUtil::EncodeEncrypt(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input) { +void CryptoUtil::PrecalculateKey(const SecretKey& secret_key, const PublicKey& public_key, PrecalcKey* precalc_key) { + assert(!crypto_box_beforenm(precalc_key->MutableKey(), public_key.Key(), secret_key.Key())); + precalc_key->MarkSet(); +} + +std::unique_ptr CryptoUtil::EncodeEncrypt(const PrecalcKey& precalc_key, const TLVNode& input) { std::string encoded; input.Encode(&encoded); @@ -59,7 +64,7 @@ std::unique_ptr CryptoUtil::EncodeEncrypt(const SecretKey& secret_key, randombytes_buf(nonce, crypto_box_NONCEBYTES); unsigned char output[encrypted_bytes]; - assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, public_key.Key(), secret_key.Key())); + assert(!crypto_box_easy_afternm(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, precalc_key.Key())); std::unique_ptr encrypted(new TLVNode(TLV_TYPE_ENCRYPTED)); encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES))); @@ -68,7 +73,7 @@ std::unique_ptr CryptoUtil::EncodeEncrypt(const SecretKey& secret_key, return encrypted; } -std::unique_ptr CryptoUtil::DecryptDecode(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input) { +std::unique_ptr CryptoUtil::DecryptDecode(const PrecalcKey& precalc_key, const TLVNode& input) { assert(input.GetType() == TLV_TYPE_ENCRYPTED); auto nonce = input.FindChild(TLV_TYPE_NONCE); @@ -83,7 +88,7 @@ std::unique_ptr CryptoUtil::DecryptDecode(const SecretKey& secret_key, size_t decrypted_bytes = encrypted->GetValue().length() - crypto_box_MACBYTES; unsigned char output[decrypted_bytes]; - if (crypto_box_open_easy(output, (const unsigned char*)encrypted->GetValue().data(), encrypted->GetValue().length(), (const unsigned char*)nonce->GetValue().data(), public_key.Key(), secret_key.Key())) { + if (crypto_box_open_easy_afternm(output, (const unsigned char*)encrypted->GetValue().data(), encrypted->GetValue().length(), (const unsigned char*)nonce->GetValue().data(), precalc_key.Key())) { return nullptr; } @@ -99,10 +104,13 @@ CryptoKey::CryptoKey(const size_t key_bytes) } CryptoKey::~CryptoKey() { - sodium_free(key_); + if (key_) { + sodium_free(key_); + } } void CryptoKey::WriteToFile(const std::string& filename) const { + assert(key_); assert(is_set_); int fd = open(filename.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0400); assert(fd != -1); @@ -111,6 +119,7 @@ void CryptoKey::WriteToFile(const std::string& filename) const { } void CryptoKey::ReadFromFile(const std::string& filename) { + assert(key_); assert(!is_set_); int fd = open(filename.c_str(), O_RDONLY); assert(fd != -1); @@ -120,25 +129,34 @@ void CryptoKey::ReadFromFile(const std::string& filename) { } const unsigned char* CryptoKey::Key() const { + assert(key_); assert(is_set_); return key_; } bool CryptoKey::IsSet() const { + assert(key_); return is_set_; } unsigned char* CryptoKey::MutableKey() { + assert(key_); assert(!is_set_); return key_; } void CryptoKey::MarkSet() { + assert(key_); assert(!is_set_); is_set_ = true; assert(!sodium_mprotect_readonly(key_)); } +void CryptoKey::Clear() { + sodium_free(key_); + key_ = nullptr; +} + SharedKey::SharedKey() : CryptoKey(crypto_secretbox_KEYBYTES) {} @@ -175,6 +193,10 @@ void PublicKey::FromString(const std::string& str) { } +PrecalcKey::PrecalcKey() + : CryptoKey(crypto_box_BEFORENMBYTES) {} + + std::ostream& CryptoBase::Log(void *obj) { char buf[64]; snprintf(buf, 64, "[%p] ", obj ? obj : this); @@ -202,9 +224,15 @@ std::unique_ptr CryptoPubConnBase::BuildSecureHandshake() { PublicKey ephemeral_public_key; CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); + if (peer_ephemeral_public_key_.IsSet()) { + CryptoUtil::PrecalculateKey(ephemeral_secret_key_, peer_ephemeral_public_key_, &ephemeral_precalc_key_); + ephemeral_secret_key_.Clear(); + peer_ephemeral_public_key_.Clear(); + } + TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE); secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key.AsString())); - return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake); + return CryptoUtil::EncodeEncrypt(precalc_key_, secure_handshake); } std::unique_ptr CryptoPubConnBase::BuildHandshake() { @@ -227,7 +255,7 @@ void CryptoPubConnBase::SendHandshake() { bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { assert(node.GetType() == TLV_TYPE_ENCRYPTED); - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(secret_key_, peer_public_key_, node)); + std::unique_ptr decrypted(CryptoUtil::DecryptDecode(precalc_key_, node)); if (!decrypted.get()) { LogFatal("Protocol error (handshake; decryption failure)"); return false; @@ -243,6 +271,11 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { return false; } peer_ephemeral_public_key_.FromString(peer_ephemeral_public_key->GetValue()); + if (ephemeral_secret_key_.IsSet()) { + CryptoUtil::PrecalculateKey(ephemeral_secret_key_, peer_ephemeral_public_key_, &ephemeral_precalc_key_); + ephemeral_secret_key_.Clear(); + peer_ephemeral_public_key_.Clear(); + } return true; } @@ -270,6 +303,7 @@ bool CryptoPubConnBase::HandleHandshake(const TLVNode& node) { } } else { peer_public_key_.FromString(peer_public_key->GetValue()); + CryptoUtil::PrecalculateKey(secret_key_, peer_public_key_, &precalc_key_); } auto encrypted = node.FindChild(TLV_TYPE_ENCRYPTED); if (!encrypted) { @@ -281,7 +315,7 @@ bool CryptoPubConnBase::HandleHandshake(const TLVNode& node) { } void CryptoPubConnBase::EncryptSend(const TLVNode& node) { - auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_secret_key_, peer_ephemeral_public_key_, node); + auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_precalc_key_, node); std::string out; encrypted->Encode(&out); bufferevent_write(bev_, out.data(), out.length()); @@ -313,7 +347,7 @@ void CryptoPubConnBase::OnReadable() { return; } - std::unique_ptr decrypted(CryptoUtil::DecryptDecode(ephemeral_secret_key_, peer_ephemeral_public_key_, *decoded)); + std::unique_ptr decrypted(CryptoUtil::DecryptDecode(ephemeral_precalc_key_, *decoded)); if (!decrypted.get()) { LogFatal("Protocol error (decryption failure)"); return; @@ -437,6 +471,7 @@ CryptoPubClient::CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const channel_bitrates_(channel_bitrates) { bev_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); peer_public_key_.FromString(server_public_key.AsString()); + CryptoUtil::PrecalculateKey(secret_key_, peer_public_key_, &precalc_key_); bufferevent_setcb(bev_, &CryptoPubClient::OnReadable_, NULL, &CryptoPubClient::OnConnectOrError_, this); bufferevent_enable(bev_, EV_READ); diff --git a/crypto.h b/crypto.h index 7cb7f02..a901931 100644 --- a/crypto.h +++ b/crypto.h @@ -20,8 +20,10 @@ class CryptoKey { unsigned char* MutableKey(); void MarkSet(); + void Clear(); + protected: - unsigned char* const key_; + unsigned char* key_; bool is_set_; const size_t key_bytes_; }; @@ -45,14 +47,20 @@ class PublicKey : public CryptoKey { void FromString(const std::string& str); }; +class PrecalcKey : public CryptoKey { + public: + PrecalcKey(); +}; + class CryptoUtil { public: static void GenKey(SharedKey* key); static void GenKeyPair(SecretKey* secret_key, PublicKey* public_key); static void DerivePublicKey(const SecretKey& secret_key, PublicKey* public_key); + static void PrecalculateKey(const SecretKey& secret_key, const PublicKey& public_key, PrecalcKey* precalc_key); - static std::unique_ptr EncodeEncrypt(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input); - static std::unique_ptr DecryptDecode(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input); + static std::unique_ptr EncodeEncrypt(const PrecalcKey& precalc_key, const TLVNode& input); + static std::unique_ptr DecryptDecode(const PrecalcKey& precalc_key, const TLVNode& input); }; class CryptoBase { @@ -91,9 +99,11 @@ class CryptoPubConnBase : public CryptoBase { const SecretKey& secret_key_; PublicKey public_key_; PublicKey peer_public_key_; + PrecalcKey precalc_key_; SecretKey ephemeral_secret_key_; PublicKey peer_ephemeral_public_key_; + PrecalcKey ephemeral_precalc_key_; }; class CryptoPubServerConnection;