From ef5323aa3f43ad46b307456fcc1c7c8eeb629f4c Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Wed, 31 May 2023 22:40:48 -0700 Subject: [PATCH] Protype Candidate --- candidate.go | 236 +++++++++++++++++++++++++++++++++++++++++++++++++++ mac.go | 14 +++ voter.go | 34 +++++--- 3 files changed, 272 insertions(+), 12 deletions(-) create mode 100644 candidate.go create mode 100644 mac.go diff --git a/candidate.go b/candidate.go new file mode 100644 index 0000000..f709b7d --- /dev/null +++ b/candidate.go @@ -0,0 +1,236 @@ +package elect + +import ( + "crypto/hmac" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/dchest/uniuri" +) + +// TODO: Ensure promotion takes longer than demotion + +type Candidate struct { + C <-chan CandidateState + + numVoters int + signingKey []byte + stop chan<- bool + done <-chan bool + resp voteResponse + c chan<- CandidateState + + votes map[string]*vote + state CandidateState + mu sync.Mutex +} + +type CandidateState string + +var ( + StateLeader CandidateState = "LEADER" + StateNotLeader CandidateState = "NOT_LEADER" +) + +func NewCandidate(numVoters int, signingKey string) *Candidate { + stop := make(chan bool) + done := make(chan bool) + change := make(chan CandidateState, 100) + + c := &Candidate{ + C: change, + numVoters: numVoters, + signingKey: []byte(signingKey), + votes: map[string]*vote{}, + stop: stop, + done: done, + c: change, + resp: voteResponse{ + CandidateID: uniuri.New(), + }, + } + + go c.loop(stop, done) + + return c +} + +func (c *Candidate) Stop() { + close(c.stop) + <-c.done +} + +func (c *Candidate) State() CandidateState { + c.mu.Lock() + defer c.mu.Unlock() + + return c.state +} + +func (c *Candidate) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error( + w, + fmt.Sprintf("method %s not supported", r.Method), + http.StatusMethodNotAllowed, + ) + + return + } + + sig := r.Header.Get("Signature") + if sig == "" { + http.Error( + w, + "missing Signature header", + http.StatusBadRequest, + ) + + return + } + + if r.Header.Get("Content-Type") != "application/json" { + http.Error( + w, + fmt.Sprintf("Content-Type %s not supported", r.Header.Get("Content-Type")), + http.StatusUnsupportedMediaType, + ) + + return + } + + js, err := io.ReadAll(r.Body) + if err != nil { + http.Error( + w, + fmt.Sprintf("can't read request body: %s", err), + http.StatusBadRequest, + ) + + return + } + + if !hmac.Equal([]byte(sig), []byte(mac(js, c.signingKey))) { + http.Error( + w, + "Signature verification failed", + http.StatusBadRequest, + ) + + return + } + + v := &vote{} + + err = json.Unmarshal(js, v) + if err != nil { + http.Error( + w, + fmt.Sprintf("can't parse request body: %s", err), + http.StatusBadRequest, + ) + + return + } + + enc := json.NewEncoder(w) + + err = enc.Encode(c.resp) + if err != nil { + http.Error( + w, + fmt.Sprintf("can't write response: %s", err), + http.StatusInternalServerError, + ) + + return + } + + c.vote(v) +} + +func (c *Candidate) vote(v *vote) { + v.received = time.Now() + + { + c.mu.Lock() + c.votes[v.VoterID] = v + c.mu.Unlock() + } + + c.elect() +} + +func (c *Candidate) voteIfNo(v *vote) { + if v.LastSeenCandidateID == c.resp.CandidateID { + return + } + + c.vote(v) +} + +func (c *Candidate) elect() { + no := 0 + yes := 0 + + cutoff := time.Now().Add(-10 * time.Second) + + c.mu.Lock() ///////////// + + for key, vote := range c.votes { + if vote.received.Before(cutoff) { + delete(c.votes, key) + continue + } + + if vote.LastSeenCandidateID != c.resp.CandidateID { + no++ + } + + if vote.NumPollsSinceChange < 10 { + continue + } + + yes++ + } + + c.mu.Unlock() //////////// + + if no == 0 && yes > c.numVoters/2 { + c.update(StateLeader) + } else { + c.update(StateNotLeader) + } +} + +func (c *Candidate) update(state CandidateState) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.state == state { + return + } + + c.state = state + c.c <- state +} + +func (c *Candidate) loop(stop <-chan bool, done chan<- bool) { + t := time.NewTicker(1 * time.Second) + defer t.Stop() + defer close(done) + + for { + select { + case <-stop: + return + + case <-t.C: + c.elect() + } + } +} diff --git a/mac.go b/mac.go new file mode 100644 index 0000000..47699be --- /dev/null +++ b/mac.go @@ -0,0 +1,14 @@ +package elect + +import ( + "crypto/hmac" + "crypto/sha256" + "fmt" +) + +func mac(payload, signingKey []byte) string { + gen := hmac.New(sha256.New, signingKey) + gen.Write(payload) + + return fmt.Sprintf("%x", gen.Sum(nil)) +} diff --git a/voter.go b/voter.go index 76c25a9..dab3ccc 100644 --- a/voter.go +++ b/voter.go @@ -1,10 +1,7 @@ package elect import ( - "crypto/hmac" - "crypto/sha256" "encoding/json" - "fmt" "log" "time" @@ -19,12 +16,16 @@ type Voter struct { update chan<- time.Duration done <-chan bool vote vote + candidates []*Candidate } type vote struct { VoterID string `json:"voterID"` LastSeenCandidateID string `json:"lastSeenCandidateID"` NumPollsSinceChange int `json:"numPollsSinceChange"` + + // Used internally by Candidate + received time.Time } type voteResponse struct { @@ -57,6 +58,10 @@ func (v *Voter) Stop() { <-v.done } +func (v *Voter) AddCandidate(c *Candidate) { + v.candidates = append(v.candidates, c) +} + func (v *Voter) loop(update <-chan time.Duration, done chan<- bool) { t := time.NewTicker(5 * time.Second) defer t.Stop() @@ -96,26 +101,31 @@ func (v *Voter) poll(update <-chan time.Duration, t *time.Ticker) bool { } func (v *Voter) sendVote() { + for _, c := range v.candidates { + c.voteIfNo(&v.vote) + } + js := lo.Must(json.Marshal(v.vote)) - - genMAC := hmac.New(sha256.New, v.signingKey) - genMAC.Write(js) - mac := fmt.Sprintf("%x", genMAC.Sum(nil)) - vr := &voteResponse{} resp, err := v.client.R(). - SetHeader("Signature", mac). + SetHeader("Signature", mac(js, v.signingKey)). SetBody(js). SetResult(vr). - Post("_vote") + Post("") if err != nil { - log.Printf("_vote response: %s", err) + log.Printf("vote response: %s", err) + + v.vote.NumPollsSinceChange = 0 + return } if resp.IsError() { - log.Printf("_vote response: [%d] %s\n%s", resp.StatusCode(), resp.Status(), resp.String()) + log.Printf("vote response: [%d] %s\n%s", resp.StatusCode(), resp.Status(), resp.String()) + + v.vote.NumPollsSinceChange = 0 + return }