diff --git a/auth-client.cc b/auth-client.cc index e015de2..24e7f03 100644 --- a/auth-client.cc +++ b/auth-client.cc @@ -46,11 +46,11 @@ int main(int argc, char *argv[]) { } } - std::string secret_key; - CryptoUtil::ReadKeyFromFile(secret_key_filename, &secret_key); + SecretKey secret_key; + secret_key.ReadFromFile(secret_key_filename); - std::string server_public_key; - CryptoUtil::ReadKeyFromFile(server_public_key_filename, &server_public_key); + PublicKey server_public_key; + server_public_key.ReadFromFile(server_public_key_filename); auto client = CryptoPubClient::FromHostname(server_address, server_port, secret_key, server_public_key, channel_bitrates); client->Loop(); diff --git a/auth-server.cc b/auth-server.cc index 31528b4..722793f 100644 --- a/auth-server.cc +++ b/auth-server.cc @@ -21,8 +21,10 @@ int main(int argc, char *argv[]) { } } - std::string secret_key; - CryptoUtil::ReadKeyFromFile(secret_key_filename, &secret_key); + sodium_init(); + + SecretKey secret_key; + secret_key.ReadFromFile(secret_key_filename); CryptoPubServer server(secret_key); server.Loop(); diff --git a/crypto.cc b/crypto.cc index b3ada92..0c99667 100644 --- a/crypto.cc +++ b/crypto.cc @@ -1,7 +1,11 @@ #include +#include #include +#include +#include #include #include +#include #include #include @@ -12,6 +16,7 @@ #include #include #include +#include #include "crypto.h" @@ -28,54 +33,23 @@ #define TLV_TYPE_CHANNEL 0x8004 -std::string CryptoUtil::BinToHex(const std::string& bin) { - static const char hex[] = "0123456789abcdef"; - std::string ret; - ret.reserve(bin.length() * 2); - for (int i = 0; i < bin.length(); i++) { - ret.push_back(hex[(bin[i] & 0xf0) >> 4]); - ret.push_back(hex[bin[i] & 0x0f]); - } - return ret; +void CryptoUtil::GenKey(SharedKey* key) { + randombytes_buf(key->MutableKey(), crypto_secretbox_KEYBYTES); + key->MarkSet(); } -void CryptoUtil::GenKey(std::string* key) { - unsigned char buf[crypto_secretbox_KEYBYTES]; - randombytes_buf(buf, crypto_secretbox_KEYBYTES); - key->assign((char*)buf, crypto_secretbox_KEYBYTES); +void CryptoUtil::GenKeyPair(SecretKey* secret_key, PublicKey* public_key) { + assert(!crypto_box_keypair(public_key->MutableKey(), secret_key->MutableKey())); + public_key->MarkSet(); + secret_key->MarkSet(); } -void CryptoUtil::GenKeyPair(std::string* secret_key, std::string* public_key) { - unsigned char public_key_buf[crypto_box_PUBLICKEYBYTES]; - unsigned char secret_key_buf[crypto_box_PUBLICKEYBYTES]; - assert(crypto_box_keypair(public_key_buf, secret_key_buf) == 0); - public_key->assign((char*)public_key_buf, crypto_box_PUBLICKEYBYTES); - secret_key->assign((char*)secret_key_buf, crypto_box_SECRETKEYBYTES); +void CryptoUtil::DerivePublicKey(const SecretKey& secret_key, PublicKey* public_key) { + assert(!crypto_scalarmult_base(public_key->MutableKey(), secret_key.Key())); + public_key->MarkSet(); } -void CryptoUtil::DerivePublicKey(const std::string& secret_key, std::string* public_key) { - assert(secret_key.length() == crypto_box_SECRETKEYBYTES); - unsigned char buf[crypto_box_PUBLICKEYBYTES]; - assert(!crypto_scalarmult_base(buf, (const unsigned char*)secret_key.data())); - public_key->assign((char*)buf, crypto_box_PUBLICKEYBYTES); -} - -void CryptoUtil::ReadKeyFromFile(const std::string& filename, std::string* key) { - std::fstream key_file(filename, std::fstream::in); - assert(!key_file.fail()); - key_file >> *key; -} - -void CryptoUtil::WriteKeyToFile(const std::string& filename, const std::string& key) { - std::fstream key_file(filename, std::fstream::out); - assert(!key_file.fail()); - key_file << key; -} - -std::unique_ptr CryptoUtil::EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { - assert(secret_key.length() == crypto_box_SECRETKEYBYTES); - assert(public_key.length() == crypto_box_PUBLICKEYBYTES); - +std::unique_ptr CryptoUtil::EncodeEncrypt(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input) { std::string encoded; input.Encode(&encoded); @@ -85,7 +59,7 @@ std::unique_ptr CryptoUtil::EncodeEncrypt(const std::string& secret_key randombytes_buf(nonce, crypto_box_NONCEBYTES); unsigned char output[encrypted_bytes]; - assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data())); + assert(!crypto_box_easy(output, (const unsigned char*)encoded.data(), encoded.length(), nonce, public_key.Key(), secret_key.Key())); std::unique_ptr encrypted(new TLVNode(TLV_TYPE_ENCRYPTED)); encrypted->AppendChild(new TLVNode(TLV_TYPE_NONCE, std::string((char*)nonce, crypto_box_NONCEBYTES))); @@ -94,9 +68,7 @@ std::unique_ptr CryptoUtil::EncodeEncrypt(const std::string& secret_key return encrypted; } -std::unique_ptr CryptoUtil::DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input) { - assert(secret_key.length() == crypto_box_SECRETKEYBYTES); - assert(public_key.length() == crypto_box_PUBLICKEYBYTES); +std::unique_ptr CryptoUtil::DecryptDecode(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input) { assert(input.GetType() == TLV_TYPE_ENCRYPTED); auto nonce = input.FindChild(TLV_TYPE_NONCE); @@ -111,13 +83,94 @@ std::unique_ptr CryptoUtil::DecryptDecode(const std::string& secret_key size_t decrypted_bytes = encrypted->GetValue().length() - crypto_box_MACBYTES; unsigned char output[decrypted_bytes]; - if (crypto_box_open_easy(output, (const unsigned char*)encrypted->GetValue().data(), encrypted->GetValue().length(), (const unsigned char*)nonce->GetValue().data(), (const unsigned char*)public_key.data(), (const unsigned char*)secret_key.data())) { + if (crypto_box_open_easy(output, (const unsigned char*)encrypted->GetValue().data(), encrypted->GetValue().length(), (const unsigned char*)nonce->GetValue().data(), public_key.Key(), secret_key.Key())) { return nullptr; } return TLVNode::Decode(std::string((char*)output, decrypted_bytes)); } + +CryptoKey::CryptoKey(const size_t key_bytes) + : key_bytes_(key_bytes), + is_set_(false), + key_((unsigned char*)sodium_malloc(key_bytes)) { + assert(key_); +} + +CryptoKey::~CryptoKey() { + sodium_free(key_); +} + +void CryptoKey::WriteToFile(const std::string& filename) const { + assert(is_set_); + int fd = open(filename.c_str(), O_WRONLY); + assert(fd != -1); + assert(write(fd, key_, key_bytes_) == key_bytes_); + assert(!close(fd)); +} + +void CryptoKey::ReadFromFile(const std::string& filename) { + assert(!is_set_); + int fd = open(filename.c_str(), O_RDONLY); + assert(fd != -1); + assert(read(fd, key_, key_bytes_ + 1) == key_bytes_); + assert(!close(fd)); + MarkSet(); +} + +const unsigned char* CryptoKey::Key() const { + assert(is_set_); + return key_; +} + +unsigned char* CryptoKey::MutableKey() { + assert(!is_set_); + return key_; +} + +void CryptoKey::MarkSet() { + assert(!is_set_); + is_set_ = true; + assert(!sodium_mprotect_readonly(key_)); +} + + +SharedKey::SharedKey() + : CryptoKey(crypto_secretbox_KEYBYTES) {} + + +SecretKey::SecretKey() + : CryptoKey(crypto_box_SECRETKEYBYTES) {} + + +PublicKey::PublicKey() + : CryptoKey(crypto_box_PUBLICKEYBYTES) {} + +std::string PublicKey::AsString() const { + assert(is_set_); + return std::string((char*)key_, key_bytes_); +} + +std::string PublicKey::ToHex() const { + static const char hex[] = "0123456789abcdef"; + std::string ret; + ret.reserve(key_bytes_ * 2); + for (int i = 0; i < key_bytes_; i++) { + ret.push_back(hex[(key_[i] & 0xf0) >> 4]); + ret.push_back(hex[key_[i] & 0x0f]); + } + return ret; +} + +void PublicKey::FromString(const std::string& str) { + assert(!is_set_); + assert(str.length() == key_bytes_); + memcpy(key_, str.data(), key_bytes_); + MarkSet(); +} + + std::ostream& CryptoBase::Log(void *obj) { char buf[64]; snprintf(buf, 64, "[%p] ", obj ? obj : this); @@ -125,7 +178,7 @@ std::ostream& CryptoBase::Log(void *obj) { } -CryptoPubConnBase::CryptoPubConnBase(const std::string& secret_key) +CryptoPubConnBase::CryptoPubConnBase(const SecretKey& secret_key) : secret_key_(secret_key), state_(AWAITING_HANDSHAKE) {} @@ -140,11 +193,11 @@ void CryptoPubConnBase::LogFatal(const std::string& msg, void *obj) { } std::unique_ptr CryptoPubConnBase::BuildSecureHandshake() { - std::string ephemeral_public_key; + PublicKey ephemeral_public_key; CryptoUtil::GenKeyPair(&ephemeral_secret_key_, &ephemeral_public_key); TLVNode secure_handshake(TLV_TYPE_HANDSHAKE_SECURE); - secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key)); + secure_handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, ephemeral_public_key.AsString())); return CryptoUtil::EncodeEncrypt(secret_key_, peer_public_key_, secure_handshake); } @@ -162,11 +215,11 @@ bool CryptoPubConnBase::HandleSecureHandshake(const TLVNode& node) { LogFatal("Protocol error (handshake; no ephemeral public key)"); return false; } - peer_ephemeral_public_key_ = peer_ephemeral_public_key->GetValue(); - if (peer_ephemeral_public_key_.length() != crypto_box_PUBLICKEYBYTES) { + if (peer_ephemeral_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) { LogFatal("Protocol error (handshake; wrong ephemeral public key length)"); return false; } + peer_ephemeral_public_key_.FromString(peer_ephemeral_public_key->GetValue()); return true; } @@ -216,14 +269,12 @@ void CryptoPubConnBase::OnReadable() { } -CryptoPubServer::CryptoPubServer(const std::string& secret_key) +CryptoPubServer::CryptoPubServer(const SecretKey& secret_key) : secret_key_(secret_key), event_base_(event_base_new()) { auto signal_event = evsignal_new(event_base_, SIGINT, &CryptoPubServer::Shutdown_, this); event_add(signal_event, NULL); - assert(secret_key_.length() == crypto_box_SECRETKEYBYTES); - struct sockaddr_in6 server_addr = {0}; server_addr.sin6_family = AF_INET6; server_addr.sin6_addr = in6addr_any; @@ -272,7 +323,7 @@ void CryptoPubServer::Shutdown() { } -CryptoPubServerConnection::CryptoPubServerConnection(struct bufferevent* bev, const std::string& secret_key) +CryptoPubServerConnection::CryptoPubServerConnection(struct bufferevent* bev, const SecretKey& secret_key) : CryptoPubConnBase(secret_key) { bev_ = bev; } @@ -292,11 +343,11 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { LogFatal("Protocol error (client handshake -- no public key)"); return; } - peer_public_key_ = peer_public_key->GetValue(); - if (peer_public_key_.length() != crypto_box_PUBLICKEYBYTES) { + if (peer_public_key->GetValue().length() != crypto_box_PUBLICKEYBYTES) { LogFatal("Protocol error (client handshake -- wrong public key length)"); return; } + peer_public_key_.FromString(peer_public_key->GetValue()); auto encrypted = decoded.FindChild(TLV_TYPE_ENCRYPTED); if (!encrypted) { LogFatal("Protocol error (client handshake -- no encrypted portion)"); @@ -310,7 +361,7 @@ void CryptoPubServerConnection::OnHandshake(const TLVNode& decoded) { SendHandshake(); this->state_ = READY; - Log() << "Handshake successful (client ID: " << CryptoUtil::BinToHex(peer_public_key_) << ")" << std::endl; + Log() << "Handshake successful (client ID: " << peer_public_key_.ToHex() << ")" << std::endl; } bool CryptoPubServerConnection::OnMessage(const TLVNode& message) { @@ -350,14 +401,12 @@ void CryptoPubServerConnection::OnError(const short what) { } -CryptoPubClient::CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const std::string& secret_key, const std::string& server_public_key, const std::list& channel_bitrates) +CryptoPubClient::CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const SecretKey& secret_key, const PublicKey& server_public_key, const std::list& channel_bitrates) : CryptoPubConnBase(secret_key), event_base_(event_base_new()), channel_bitrates_(channel_bitrates) { bev_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE); - peer_public_key_ = server_public_key; - assert(secret_key_.length() == crypto_box_SECRETKEYBYTES); - assert(peer_public_key_.length() == crypto_box_PUBLICKEYBYTES); + peer_public_key_.FromString(server_public_key.AsString()); bufferevent_setcb(bev_, &CryptoPubClient::OnReadable_, NULL, &CryptoPubClient::OnConnectOrError_, this); bufferevent_enable(bev_, EV_READ); @@ -369,7 +418,7 @@ CryptoPubClient::~CryptoPubClient() { event_base_free(event_base_); } -CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address, const std::string& server_port, const std::string& secret_key, const std::string& server_public_key, const std::list& channel_bitrates) { +CryptoPubClient* CryptoPubClient::FromHostname(const std::string& server_address, const std::string& server_port, const SecretKey& secret_key, const PublicKey& server_public_key, const std::list& channel_bitrates) { struct addrinfo* res; int gai_ret = getaddrinfo(server_address.c_str(), server_port.c_str(), NULL, &res); if (gai_ret) { @@ -417,9 +466,9 @@ void CryptoPubClient::SendHandshake() { auto secure_handshake = BuildSecureHandshake(); TLVNode handshake(TLV_TYPE_HANDSHAKE); - std::string public_key; + PublicKey public_key; CryptoUtil::DerivePublicKey(secret_key_, &public_key); - handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key)); + handshake.AppendChild(new TLVNode(TLV_TYPE_PUBLIC_KEY, public_key.AsString())); handshake.AppendChild(secure_handshake.release()); std::string out; diff --git a/crypto.h b/crypto.h index 5083643..a88d38d 100644 --- a/crypto.h +++ b/crypto.h @@ -1,23 +1,57 @@ #include #include #include +#include #include #include "tlv.h" +class CryptoKey { + public: + CryptoKey(const size_t key_bytes); + ~CryptoKey(); + void ReadFromFile(const std::string& filename); + void WriteToFile(const std::string& filename) const; + + const unsigned char* Key() const; + + unsigned char* MutableKey(); + void MarkSet(); + + protected: + unsigned char* const key_; + bool is_set_; + const size_t key_bytes_; +}; + +class SharedKey : public CryptoKey { + public: + SharedKey(); +}; + +class SecretKey : public CryptoKey { + public: + SecretKey(); +}; + +class PublicKey : public CryptoKey { + public: + PublicKey(); + + std::string AsString() const; + std::string ToHex() const; + void FromString(const std::string& str); +}; + class CryptoUtil { public: - static std::string BinToHex(const std::string& bin); + static void GenKey(SharedKey* key); + static void GenKeyPair(SecretKey* secret_key, PublicKey* public_key); + static void DerivePublicKey(const SecretKey& secret_key, PublicKey* public_key); - static void GenKey(std::string* key); - static void GenKeyPair(std::string* secret_key, std::string* public_key); - static void DerivePublicKey(const std::string& secret_key, std::string* public_key); - static void ReadKeyFromFile(const std::string& filename, std::string* key); - static void WriteKeyToFile(const std::string& filename, const std::string& key); - - static std::unique_ptr EncodeEncrypt(const std::string& secret_key, const std::string& public_key, const TLVNode& input); - static std::unique_ptr DecryptDecode(const std::string& secret_key, const std::string& public_key, const TLVNode& input); + static std::unique_ptr EncodeEncrypt(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input); + static std::unique_ptr DecryptDecode(const SecretKey& secret_key, const PublicKey& public_key, const TLVNode& input); }; class CryptoBase { @@ -27,7 +61,7 @@ class CryptoBase { class CryptoPubConnBase : public CryptoBase { protected: - CryptoPubConnBase(const std::string& secret_key); + CryptoPubConnBase(const SecretKey& secret_key); virtual ~CryptoPubConnBase(); void LogFatal(const std::string& msg, void *obj=nullptr); @@ -48,17 +82,17 @@ class CryptoPubConnBase : public CryptoBase { struct bufferevent* bev_; - const std::string secret_key_; - std::string peer_public_key_; - std::string ephemeral_secret_key_; - std::string peer_ephemeral_public_key_; + const SecretKey& secret_key_; + PublicKey peer_public_key_; + SecretKey ephemeral_secret_key_; + PublicKey peer_ephemeral_public_key_; }; class CryptoPubServerConnection; class CryptoPubServer : public CryptoBase { public: - CryptoPubServer(const std::string& secret_key); + CryptoPubServer(const SecretKey& secret_key); ~CryptoPubServer(); void Loop(); void Shutdown(); @@ -72,12 +106,12 @@ class CryptoPubServer : public CryptoBase { struct event_base* event_base_; struct evconnlistener* listener_; - const std::string secret_key_; + const SecretKey& secret_key_; }; class CryptoPubServerConnection : public CryptoPubConnBase { public: - CryptoPubServerConnection(struct bufferevent* bev, const std::string& secret_key); + CryptoPubServerConnection(struct bufferevent* bev, const SecretKey& secret_key); ~CryptoPubServerConnection(); private: @@ -95,10 +129,10 @@ class CryptoPubServerConnection : public CryptoPubConnBase { class CryptoPubClient : public CryptoPubConnBase { public: - CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const std::string& secret_key, const std::string& server_public_key, const std::list& channel_bitrates); + CryptoPubClient(struct sockaddr* addr, socklen_t addrlen, const SecretKey& secret_key, const PublicKey& server_public_key, const std::list& channel_bitrates); ~CryptoPubClient(); - static CryptoPubClient* FromHostname(const std::string& server_address, const std::string& server_port, const std::string& secret_key, const std::string& server_public_key, const std::list& channel_bitrates); + static CryptoPubClient* FromHostname(const std::string& server_address, const std::string& server_port, const SecretKey& secret_key, const PublicKey& server_public_key, const std::list& channel_bitrates); void Loop(); diff --git a/gen-key.cc b/gen-key.cc index 5dc37c5..ca4ceaa 100644 --- a/gen-key.cc +++ b/gen-key.cc @@ -10,10 +10,10 @@ int main(int argc, char *argv[]) { return 1; } - std::string key; + SharedKey key; CryptoUtil::GenKey(&key); - CryptoUtil::WriteKeyToFile(argv[1], key); + key.WriteToFile(argv[1]); return 0; } diff --git a/gen-keypair.cc b/gen-keypair.cc index fd5e19e..4169ca6 100644 --- a/gen-keypair.cc +++ b/gen-keypair.cc @@ -10,11 +10,12 @@ int main(int argc, char *argv[]) { return 1; } - std::string secret_key, public_key; + SecretKey secret_key; + PublicKey public_key; CryptoUtil::GenKeyPair(&secret_key, &public_key); - CryptoUtil::WriteKeyToFile(argv[1], secret_key); - CryptoUtil::WriteKeyToFile(argv[2], public_key); + secret_key.WriteToFile(argv[1]); + public_key.WriteToFile(argv[2]); return 0; }