Make the handshake mirrored again, for common code and to support future key rotation.

This commit is contained in:
Ian Gulliver
2015-02-08 19:50:09 +00:00
parent 7807df9575
commit 6a4a92f47a
2 changed files with 66 additions and 48 deletions

105
crypto.cc
View File

@@ -104,7 +104,7 @@ CryptoKey::~CryptoKey() {
void CryptoKey::WriteToFile(const std::string& filename) const {
assert(is_set_);
int fd = open(filename.c_str(), O_WRONLY);
int fd = open(filename.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0400);
assert(fd != -1);
assert(write(fd, key_, key_bytes_) == key_bytes_);
assert(!close(fd));
@@ -124,6 +124,10 @@ const unsigned char* CryptoKey::Key() const {
return key_;
}
bool CryptoKey::IsSet() const {
return is_set_;
}
unsigned char* CryptoKey::MutableKey() {
assert(!is_set_);
return key_;
@@ -201,6 +205,25 @@ std::unique_ptr<TLVNode> CryptoPubConnBase::BuildSecureHandshake() {
return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake);
}
std::unique_ptr<TLVNode> CryptoPubConnBase::BuildHandshake() {
auto secure_handshake = BuildSecureHandshake();
std::unique_ptr<TLVNode> handshake(new TLVNode(TLV_TYPE_HANDSHAKE));
PublicKey public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake->AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString()));
handshake->AppendChild(secure_handshake.release());
return handshake;
}
void CryptoPubConnBase::SendHandshake() {
auto handshake = BuildHandshake();
std::string out;
handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) {
assert(node.GetType() == TLV_TYPE_ENCRYPTED);
@@ -223,6 +246,40 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) {
return true;
}
bool CryptoPubConnBase::HandleHandshake(const TLVNode& node) {
if (node.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal("Protocol error (handshake; wrong message type)");
return false;
}
auto peer_public_key = node.FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_public_key) {
LogFatal("Protocol error (handshake; no public key)");
return false;
}
if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) {
LogFatal("Protocol error (handshake; wrong public key length)");
return false;
}
if (peer_public_key_.IsSet()) {
// We're the client and already know the server public key; we expect these to match.
// Eventually, we can do smarter things here to allow key rotation.
if (peer_public_key_.AsString() != peer_public_key->GetValue()) {
LogFatal("Protocol error (handshake; public key mismatch)");
return false;
}
} else {
peer_public_key_.FromString(peer_public_key->GetValue());
}
auto encrypted = node.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal("Protocol error (handshake; no encrypted portion)");
return false;
}
return HandleSecureHandshake(*encrypted);
}
void CryptoPubConnBase::EncryptSend(const TLVNode& node) {
auto encrypted = CryptoUtil::EncodeEncrypt(ephemeral_secret_key_, peer_ephemeral_public_key_, node);
std::string out;
@@ -334,28 +391,7 @@ CryptoPubServerConnection::~CryptoPubServerConnection() {
}
void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) {
if (decoded.GetType() != TLV_TYPE_HANDSHAKE) {
LogFatal("Protocol error (client handshake -- wrong message type)");
return;
}
auto peer_public_key = decoded.FindChild(TLV_TYPE_PUBLIC_KEY);
if (!peer_public_key) {
LogFatal("Protocol error (client handshake -- no public key)");
return;
}
if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) {
LogFatal("Protocol error (client handshake -- wrong public key length)");
return;
}
peer_public_key_.FromString(peer_public_key->GetValue());
auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED);
if (!encrypted) {
LogFatal("Protocol error (client handshake -- no encrypted portion)");
return;
}
if (!HandleSecureHandshake(*encrypted)) {
if (!HandleHandshake(decoded)) {
return;
}
@@ -385,13 +421,6 @@ bool CryptoPubServerConnection::OnTunnelRequest(const TLVNode& message) {
return true;
}
void CryptoPubServerConnection::SendHandshake() {
auto handshake = BuildSecureHandshake();
std::string out;
handshake->Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
void CryptoPubServerConnection::OnError_(struct bufferevent* bev, const short what, void* this__) {
auto this_ = (CryptoPubServerConnection*)this__;
this_->OnError(what);
@@ -432,7 +461,7 @@ CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address
}
void CryptoPubClient::OnHandshake(const TLVNode& decoded) {
if (!HandleSecureHandshake(decoded)) {
if (!HandleHandshake(decoded)) {
return;
}
@@ -463,20 +492,6 @@ void CryptoPubClient::OnConnect() {
SendHandshake();
}
void CryptoPubClient::SendHandshake() {
auto secure_handshake = BuildSecureHandshake();
TLVNode handshake(TLV_TYPE_HANDSHAKE);
PublicKey public_key;
CryptoUtil::DerivePublicKey(secret_key_, &public_key);
handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString()));
handshake.AppendChild(secure_handshake.release());
std::string out;
handshake.Encode(&out);
bufferevent_write(bev_, out.data(), out.length());
}
void CryptoPubClient::SendTunnelRequest() {
TLVNode tunnel_request(TLV_TYPE_TUNNEL_REQUEST);
for (auto channel_bitrate : channel_bitrates_) {