From cb9d3aef6203ff4728f8e0150ece25318e6f6f69 Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Thu, 20 Apr 2023 17:54:42 +0000 Subject: [PATCH] Initial commit --- .gitignore | 2 + .golangci.yaml | 38 ++++++++ bodyintercept.go | 30 ++++++ go.mod | 20 ++++ go.sum | 40 ++++++++ justfile | 18 ++++ pkg_test.go | 11 +++ potency.go | 189 +++++++++++++++++++++++++++++++++++++ potency_test.go | 183 +++++++++++++++++++++++++++++++++++ responsewriterintercept.go | 34 +++++++ 10 files changed, 565 insertions(+) create mode 100644 .gitignore create mode 100644 .golangci.yaml create mode 100644 bodyintercept.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 justfile create mode 100644 pkg_test.go create mode 100644 potency.go create mode 100644 potency_test.go create mode 100644 responsewriterintercept.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6754c7d --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +cover.out +cover.html diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..9aa3e2e --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,38 @@ +linters: + enable-all: true + disable: + # re-enable when working + - rowserrcheck + - wastedassign + # maybe enable these + - wrapcheck + # leave these disabled + - cyclop + - deadcode + - dupl + - exhaustivestruct + - exhaustruct + - forbidigo + - forcetypeassert + - funlen + - gochecknoglobals + - gocognit + - goconst + - godox + - golint + - gomnd + - ifshort + - interfacer + - lll + - maintidx + - maligned + - nilnil + - nestif + - nlreturn + - nolintlint + - nosnakecase + - scopelint + - structcheck + - thelper + - varcheck + - varnamelen diff --git a/bodyintercept.go b/bodyintercept.go new file mode 100644 index 0000000..5defae0 --- /dev/null +++ b/bodyintercept.go @@ -0,0 +1,30 @@ +package potency + +import ( + "crypto/sha256" + "hash" + "io" +) + +type bodyIntercept struct { + source io.ReadCloser + sha256 hash.Hash +} + +func newBodyIntercept(source io.ReadCloser) *bodyIntercept { + return &bodyIntercept{ + source: source, + sha256: sha256.New(), + } +} + +func (bi *bodyIntercept) Read(p []byte) (int, error) { + numBytes, err := bi.source.Read(p) + bi.sha256.Write(p[:numBytes]) + + return numBytes, err +} + +func (bi *bodyIntercept) Close() error { + return bi.source.Close() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8a8eb9c --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module github.com/gopatchy/potency + +go 1.19 + +require ( + github.com/dchest/uniuri v1.2.0 + github.com/go-resty/resty/v2 v2.7.0 + github.com/gopatchy/jsrest v0.0.0-20230420161234-12a6d6da8b7f + github.com/stretchr/testify v1.8.2 + go.uber.org/goleak v1.2.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gopatchy/metadata v0.0.0-20230420053349-25837551c11d // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/vfaronov/httpheader v0.1.0 // indirect + golang.org/x/net v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..75a4118 --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= +github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= +github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= +github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= +github.com/gopatchy/jsrest v0.0.0-20230420161234-12a6d6da8b7f h1:1uGPJm9K0Fro1UEcZpuK6FNPU/U1XX3aS3x0/PdFS40= +github.com/gopatchy/jsrest v0.0.0-20230420161234-12a6d6da8b7f/go.mod h1:Ryi8LRBLFDhQsMQHuh+6VL7HcFWjBXOEiOy9Ip/Q+Ps= +github.com/gopatchy/metadata v0.0.0-20230420053349-25837551c11d h1:chunoM47vkWSanIvLx4uRSkLMG6chDZOy09L2tt/bv8= +github.com/gopatchy/metadata v0.0.0-20230420053349-25837551c11d/go.mod h1:VgD33raUShjDePCDBo55aj+eSXFtUEpMzs+Ie39g2zo= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4 h1:Ha8xCaq6ln1a+R91Km45Oq6lPXj2Mla6CRJYcuV2h1w= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/vfaronov/httpheader v0.1.0 h1:VdzetvOKRoQVHjSrXcIOwCV6JG5BCAW9rjbVbFPBmb0= +github.com/vfaronov/httpheader v0.1.0/go.mod h1:ZBxgbYu6nbN5V9Ptd1yYUUan0voD0O8nZLXHyxLgoLE= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/justfile b/justfile new file mode 100644 index 0000000..d607f93 --- /dev/null +++ b/justfile @@ -0,0 +1,18 @@ +go := env_var_or_default('GOCMD', 'go') + +default: tidy test + +tidy: + {{go}} mod tidy + goimports -l -w . + gofumpt -l -w . + {{go}} fmt ./... + +test: + {{go}} vet ./... + golangci-lint run ./... + {{go}} test -race -coverprofile=cover.out -timeout=60s -parallel=10 ./... + {{go}} tool cover -html=cover.out -o=cover.html + +todo: + -git grep -e TODO --and --not -e ignoretodo diff --git a/pkg_test.go b/pkg_test.go new file mode 100644 index 0000000..54e9aba --- /dev/null +++ b/pkg_test.go @@ -0,0 +1,11 @@ +package potency_test + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/potency.go b/potency.go new file mode 100644 index 0000000..8704369 --- /dev/null +++ b/potency.go @@ -0,0 +1,189 @@ +package potency + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "github.com/gopatchy/jsrest" +) + +type Potency struct { + handler http.Handler + + // TODO: Expire based on time; probably use age-based linked list and delete at write time + cache map[string]*savedResult + cacheMu sync.RWMutex + + inProgress map[string]bool + inProgressMu sync.Mutex +} + +type savedResult struct { + Method string `json:"method"` + URL string `json:"url"` + RequestHeader http.Header `json:"requestHeader"` + Sha256 string `json:"sha256"` + + StatusCode int `json:"statusCode"` + ResponseHeader http.Header `json:"responseHeader"` + ResponseBody []byte `json:"responseBody"` +} + +var ( + ErrConflict = errors.New("conflict") + ErrMismatch = errors.New("idempotency mismatch") + ErrBodyMismatch = fmt.Errorf("request body mismatch: %w", ErrMismatch) + ErrMethodMismatch = fmt.Errorf("HTTP method mismatch: %w", ErrMismatch) + ErrURLMismatch = fmt.Errorf("URL mismatch: %w", ErrMismatch) + ErrHeaderMismatch = fmt.Errorf("Header mismatch: %w", ErrMismatch) + ErrInvalidKey = errors.New("invalid Idempotency-Key") + + criticalHeaders = []string{ + "Accept", + "Authorization", + } +) + +func NewPotency(handler http.Handler) *Potency { + return &Potency{ + handler: handler, + cache: map[string]*savedResult{}, + inProgress: map[string]bool{}, + } +} + +func (p *Potency) ServeHTTP(w http.ResponseWriter, r *http.Request) { + val := r.Header.Get("Idempotency-Key") + if val == "" { + p.handler.ServeHTTP(w, r) + return + } + + err := p.serveHTTP(w, r, val) + if err != nil { + jsrest.WriteError(w, err) + } +} + +func (p *Potency) serveHTTP(w http.ResponseWriter, r *http.Request, val string) error { + if len(val) < 2 || !strings.HasPrefix(val, `"`) || !strings.HasSuffix(val, `"`) { + return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", val, ErrInvalidKey) + } + + key := val[1 : len(val)-1] + + saved := p.read(key) + + if saved != nil { + if r.Method != saved.Method { + return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.Method, ErrMethodMismatch) + } + + if r.URL.String() != saved.URL { + return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.URL.String(), ErrURLMismatch) + } + + for _, h := range criticalHeaders { + if saved.RequestHeader.Get(h) != r.Header.Get(h) { + return jsrest.Errorf(jsrest.ErrBadRequest, "%s: %s (%w)", h, r.Header.Get(h), ErrHeaderMismatch) + } + } + + h := sha256.New() + + _, err := io.Copy(h, r.Body) + if err != nil { + return jsrest.Errorf(jsrest.ErrBadRequest, "hash request body failed (%w)", err) + } + + hexed := hex.EncodeToString(h.Sum(nil)) + if hexed != saved.Sha256 { + return jsrest.Errorf(jsrest.ErrBadRequest, "%s vs %s (%w)", hexed, saved.Sha256, ErrBodyMismatch) + } + + for key, vals := range saved.ResponseHeader { + w.Header().Set(key, vals[0]) + } + + w.WriteHeader(saved.StatusCode) + _, _ = w.Write(saved.ResponseBody) + + return nil + } + + // Store miss, proceed to normal execution with interception + err := p.lockKey(key) + if err != nil { + return jsrest.Errorf(jsrest.ErrConflict, "%s", key) + } + + defer p.unlockKey(key) + + requestHeader := http.Header{} + for _, h := range criticalHeaders { + requestHeader.Set(h, r.Header.Get(h)) + } + + bi := newBodyIntercept(r.Body) + r.Body = bi + + rwi := newResponseWriterIntercept(w) + w = rwi + + p.handler.ServeHTTP(w, r) + + save := &savedResult{ + Method: r.Method, + URL: r.URL.String(), + RequestHeader: requestHeader, + Sha256: hex.EncodeToString(bi.sha256.Sum(nil)), + + StatusCode: rwi.statusCode, + ResponseHeader: rwi.Header(), + ResponseBody: rwi.buf.Bytes(), + } + + p.write(key, save) + + return nil +} + +func (p *Potency) lockKey(key string) error { + p.inProgressMu.Lock() + defer p.inProgressMu.Unlock() + + if p.inProgress[key] { + return ErrConflict + } + + p.inProgress[key] = true + + return nil +} + +func (p *Potency) unlockKey(key string) { + p.inProgressMu.Lock() + defer p.inProgressMu.Unlock() + + delete(p.inProgress, key) +} + +func (p *Potency) read(key string) *savedResult { + p.cacheMu.RLock() + defer p.cacheMu.RUnlock() + + return p.cache[key] +} + +func (p *Potency) write(key string, sr *savedResult) { + p.cacheMu.Lock() + defer p.cacheMu.Unlock() + + p.cache[key] = sr +} diff --git a/potency_test.go b/potency_test.go new file mode 100644 index 0000000..1fdcc0d --- /dev/null +++ b/potency_test.go @@ -0,0 +1,183 @@ +package potency_test + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "testing" + "time" + + "github.com/dchest/uniuri" + "github.com/go-resty/resty/v2" + "github.com/gopatchy/potency" + "github.com/stretchr/testify/require" +) + +func TestGET(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + defer ts.shutdown(t) + + key1 := uniuri.New() + + resp, err := ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + Get("") + require.NoError(t, err) + require.False(t, resp.IsError()) + require.Equal(t, "bar", resp.Header().Get("X-Response")) + + resp1 := resp.String() + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + Get("") + require.NoError(t, err) + require.False(t, resp.IsError()) + require.Equal(t, "bar", resp.Header().Get("X-Response")) + require.Equal(t, resp1, resp.String()) + + key2 := uniuri.New() + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key2)). + Get("") + require.NoError(t, err) + require.False(t, resp.IsError()) + require.Equal(t, "bar", resp.Header().Get("X-Response")) + + resp2 := resp.String() + + require.NotEqual(t, resp2, resp1) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + Get("x") + require.NoError(t, err) + require.True(t, resp.IsError()) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + Delete("") + require.NoError(t, err) + require.True(t, resp.IsError()) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetHeader("Authorization", "Bearer xyz"). + Get("") + require.NoError(t, err) + require.True(t, resp.IsError()) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetHeader("Accept", "text/xml"). + Get("") + require.NoError(t, err) + require.True(t, resp.IsError()) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetHeader("X-Test", "foo"). + Get("") + require.NoError(t, err) + require.False(t, resp.IsError()) + require.Equal(t, "bar", resp.Header().Get("X-Response")) + require.Equal(t, resp1, resp.String()) +} + +func TestPOST(t *testing.T) { + t.Parallel() + + ts := newTestServer(t) + defer ts.shutdown(t) + + key1 := uniuri.New() + + resp, err := ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetBody("test1"). + Post("") + require.NoError(t, err) + require.False(t, resp.IsError()) + + resp1 := resp.String() + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetBody("test1"). + Post("") + require.NoError(t, err) + require.False(t, resp.IsError()) + require.Equal(t, resp1, resp.String()) + + resp, err = ts.r(). + SetHeader("Idempotency-Key", fmt.Sprintf(`"%s"`, key1)). + SetBody("test2"). + Post("") + require.NoError(t, err) + require.True(t, resp.IsError()) +} + +type testServer struct { + dir string + srv *http.Server + rst *resty.Client +} + +func newTestServer(t *testing.T) *testServer { + dir, err := os.MkdirTemp("", "") + require.NoError(t, err) + + mux := http.NewServeMux() + p := potency.NewPotency(mux) + + listener, err := net.Listen("tcp", "[::]:0") + require.NoError(t, err) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, err := io.ReadAll(r.Body) + require.NoError(t, err) + + w.Header().Add("X-Response", "bar") + + _, err = w.Write([]byte(uniuri.New())) + require.NoError(t, err) + }) + + srv := &http.Server{ + Handler: p, + ReadHeaderTimeout: 1 * time.Second, + } + + go func() { + _ = srv.Serve(listener) + }() + + baseURL := fmt.Sprintf("http://[::1]:%d/", listener.Addr().(*net.TCPAddr).Port) + + rst := resty.New(). + SetHeader("Content-Type", "application/json"). + SetBaseURL(baseURL) + + return &testServer{ + dir: dir, + srv: srv, + rst: rst, + } +} + +func (ts *testServer) r() *resty.Request { + return ts.rst.R() +} + +func (ts *testServer) shutdown(t *testing.T) { + err := ts.srv.Shutdown(context.Background()) + require.NoError(t, err) + + os.RemoveAll(ts.dir) +} diff --git a/responsewriterintercept.go b/responsewriterintercept.go new file mode 100644 index 0000000..3360f9d --- /dev/null +++ b/responsewriterintercept.go @@ -0,0 +1,34 @@ +package potency + +import ( + "bytes" + "net/http" +) + +type responseWriterIntercept struct { + dest http.ResponseWriter + buf bytes.Buffer + statusCode int +} + +func newResponseWriterIntercept(dest http.ResponseWriter) *responseWriterIntercept { + return &responseWriterIntercept{ + dest: dest, + buf: bytes.Buffer{}, + statusCode: http.StatusOK, + } +} + +func (rwi *responseWriterIntercept) Header() http.Header { + return rwi.dest.Header() +} + +func (rwi *responseWriterIntercept) Write(data []byte) (int, error) { + rwi.buf.Write(data) + return rwi.dest.Write(data) +} + +func (rwi *responseWriterIntercept) WriteHeader(statusCode int) { + rwi.statusCode = statusCode + rwi.dest.WriteHeader(statusCode) +}