From 8c8133ff5286c2b7bd8ca475fffceb966bfaf230 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Sat, 3 Jun 2023 21:12:48 -0700 Subject: [PATCH] Sign responses as well --- candidate.go | 9 +++++---- voter.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/candidate.go b/candidate.go index c1cbe7d..52f255c 100644 --- a/candidate.go +++ b/candidate.go @@ -10,6 +10,7 @@ import ( "time" "github.com/dchest/uniuri" + "github.com/samber/lo" ) // TODO: Ensure promotion takes longer than demotion @@ -141,12 +142,12 @@ func (c *Candidate) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + js = lo.Must(json.Marshal(c.resp)) + w.Header().Set("Content-Type", "application/json") - // TODO: Sign responses + w.Header().Set("Signature", mac(js, c.signingKey)) - enc := json.NewEncoder(w) - - err = enc.Encode(c.resp) + _, err = w.Write(js) if err != nil { http.Error( w, diff --git a/voter.go b/voter.go index f361709..3f995fb 100644 --- a/voter.go +++ b/voter.go @@ -1,6 +1,7 @@ package elect import ( + "crypto/hmac" "encoding/json" "log" "time" @@ -109,13 +110,12 @@ func (v *Voter) sendVote() { } js := lo.Must(json.Marshal(v.vote)) - vr := &voteResponse{} resp, err := v.client.R(). SetHeader("Signature", mac(js, v.signingKey)). SetHeader("Content-Type", "application/json"). + SetHeader("Accept", "application/json"). SetBody(js). - SetResult(vr). Post("") if err != nil { log.Printf("vote response: %s", err) @@ -126,13 +126,32 @@ func (v *Voter) sendVote() { } if resp.IsError() { - log.Printf("vote response: [%d] %s\n%s", resp.StatusCode(), resp.Status(), resp.String()) + v.log("response: [%d] %s\n%s", resp.StatusCode(), resp.Status(), resp.String()) v.vote.NumPollsSinceChange = 0 return } + sig := resp.Header().Get("Signature") + if sig == "" { + v.log("missing Signature response header") + return + } + + if !hmac.Equal([]byte(sig), []byte(mac(resp.Body(), v.signingKey))) { + v.log("invalid Signature response header") + return + } + + vr := &voteResponse{} + + err = json.Unmarshal(resp.Body(), vr) + if err != nil { + v.log("invalid response: %s", resp.String()) + return + } + if vr.CandidateID == v.vote.LastSeenCandidateID { v.vote.NumPollsSinceChange++ } else { @@ -140,3 +159,7 @@ func (v *Voter) sendVote() { v.vote.NumPollsSinceChange = 0 } } + +func (v *Voter) log(format string, args ...any) { + log.Printf("[voter] "+format, args...) +}