In-memory cleanup

This commit is contained in:
Ian Gulliver
2023-04-24 21:55:52 +00:00
parent d4766011e8
commit c0060eb030

View File

@@ -1,8 +1,8 @@
package potency package potency
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -25,14 +25,14 @@ type Potency struct {
} }
type savedResult struct { type savedResult struct {
Method string `json:"method"` method string
URL string `json:"url"` url string
RequestHeader http.Header `json:"requestHeader"` requestHeader http.Header
Sha256 string `json:"sha256"` sha256 []byte
StatusCode int `json:"statusCode"` statusCode int
ResponseHeader http.Header `json:"responseHeader"` responseHeader http.Header
ResponseBody []byte `json:"responseBody"` responseBody []byte
} }
var ( var (
@@ -81,16 +81,16 @@ func (p *Potency) serveHTTP(w http.ResponseWriter, r *http.Request, val string)
saved := p.read(key) saved := p.read(key)
if saved != nil { if saved != nil {
if r.Method != saved.Method { if r.Method != saved.method {
return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.Method, ErrMethodMismatch) return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.Method, ErrMethodMismatch)
} }
if r.URL.String() != saved.URL { if r.URL.String() != saved.url {
return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.URL.String(), ErrURLMismatch) return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.URL.String(), ErrURLMismatch)
} }
for _, h := range criticalHeaders { for _, h := range criticalHeaders {
if saved.RequestHeader.Get(h) != r.Header.Get(h) { if saved.requestHeader.Get(h) != r.Header.Get(h) {
return jsrest.Errorf(jsrest.ErrBadRequest, "%s: %s (%w)", h, r.Header.Get(h), ErrHeaderMismatch) return jsrest.Errorf(jsrest.ErrBadRequest, "%s: %s (%w)", h, r.Header.Get(h), ErrHeaderMismatch)
} }
} }
@@ -102,17 +102,17 @@ func (p *Potency) serveHTTP(w http.ResponseWriter, r *http.Request, val string)
return jsrest.Errorf(jsrest.ErrBadRequest, "hash request body failed (%w)", err) return jsrest.Errorf(jsrest.ErrBadRequest, "hash request body failed (%w)", err)
} }
hexed := hex.EncodeToString(h.Sum(nil)) sha256 := h.Sum(nil)
if hexed != saved.Sha256 { if !bytes.Equal(sha256, saved.sha256) {
return jsrest.Errorf(jsrest.ErrBadRequest, "%s vs %s (%w)", hexed, saved.Sha256, ErrBodyMismatch) return jsrest.Errorf(jsrest.ErrBadRequest, "%s vs %s (%w)", sha256, saved.sha256, ErrBodyMismatch)
} }
for key, vals := range saved.ResponseHeader { for key, vals := range saved.responseHeader {
w.Header().Set(key, vals[0]) w.Header().Set(key, vals[0])
} }
w.WriteHeader(saved.StatusCode) w.WriteHeader(saved.statusCode)
_, _ = w.Write(saved.ResponseBody) _, _ = w.Write(saved.responseBody)
return nil return nil
} }
@@ -139,14 +139,14 @@ func (p *Potency) serveHTTP(w http.ResponseWriter, r *http.Request, val string)
p.handler.ServeHTTP(w, r) p.handler.ServeHTTP(w, r)
save := &savedResult{ save := &savedResult{
Method: r.Method, method: r.Method,
URL: r.URL.String(), url: r.URL.String(),
RequestHeader: requestHeader, requestHeader: requestHeader,
Sha256: hex.EncodeToString(bi.sha256.Sum(nil)), sha256: bi.sha256.Sum(nil),
StatusCode: rwi.statusCode, statusCode: rwi.statusCode,
ResponseHeader: rwi.Header(), responseHeader: rwi.Header(),
ResponseBody: rwi.buf.Bytes(), responseBody: rwi.buf.Bytes(),
} }
p.write(key, save) p.write(key, save)