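// Package potency provides HTTP middleware implementing Idempotency-Key
// request handling: the response to a keyed request is cached and replayed
// when the same request is retried with the same key.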
package potency
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gopatchy/jsrest"
|
|
)
|
|
|
|
type Potency struct {
|
|
handler http.Handler
|
|
|
|
lifetime time.Duration
|
|
|
|
cache map[string]*savedResult
|
|
cacheOldest *savedResult
|
|
cacheNewest *savedResult
|
|
cacheMu sync.RWMutex
|
|
|
|
inProgress map[string]bool
|
|
inProgressMu sync.Mutex
|
|
}
|
|
|
|
// savedResult is one cached response, together with a fingerprint of the
// request that produced it so a reused key can be checked against it.
type savedResult struct {
	key string

	// Fingerprint of the original request.
	method, url   string
	requestHeader http.Header
	sha256        []byte

	// Recorded response.
	statusCode     int
	responseHeader http.Header
	responseBody   []byte

	// Expiry bookkeeping: insertion time and the next-newer list entry.
	added time.Time
	newer *savedResult
}
|
|
|
|
var (
	// ErrConflict is returned by lockKey when a request with the same
	// Idempotency-Key is already executing.
	ErrConflict = errors.New("conflict")

	// ErrMismatch is wrapped by every error reporting that a reused
	// Idempotency-Key arrived with a different request; match it with
	// errors.Is.
	ErrMismatch = errors.New("idempotency mismatch")

	ErrBodyMismatch   = fmt.Errorf("request body mismatch: %w", ErrMismatch)
	ErrMethodMismatch = fmt.Errorf("HTTP method mismatch: %w", ErrMismatch)
	ErrURLMismatch    = fmt.Errorf("URL mismatch: %w", ErrMismatch)
	// Lower-case per Go error-string convention, like its siblings.
	ErrHeaderMismatch = fmt.Errorf("header mismatch: %w", ErrMismatch)

	// ErrInvalidKey reports an Idempotency-Key header value that is not a
	// double-quoted string.
	ErrInvalidKey = errors.New("invalid Idempotency-Key")

	// criticalHeaders are request headers that are recorded with a saved
	// result and must match when its key is reused.
	criticalHeaders = []string{
		"Accept",
		"Authorization",
	}
)
|
|
|
|
func NewPotency(handler http.Handler) *Potency {
|
|
return &Potency{
|
|
handler: handler,
|
|
lifetime: 6 * time.Hour,
|
|
cache: map[string]*savedResult{},
|
|
inProgress: map[string]bool{},
|
|
}
|
|
}
|
|
|
|
func (p *Potency) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
val := r.Header.Get("Idempotency-Key")
|
|
if val == "" {
|
|
p.handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
err := p.serveHTTP(w, r, val)
|
|
if err != nil {
|
|
jsrest.WriteError(w, err)
|
|
}
|
|
}
|
|
|
|
func (p *Potency) SetLifetime(lifetime time.Duration) {
|
|
p.cacheMu.Lock()
|
|
defer p.cacheMu.Unlock()
|
|
|
|
p.lifetime = lifetime
|
|
}
|
|
|
|
func (p *Potency) NumCached() int {
|
|
p.cacheMu.RLock()
|
|
defer p.cacheMu.RUnlock()
|
|
|
|
return len(p.cache)
|
|
}
|
|
|
|
func (p *Potency) serveHTTP(w http.ResponseWriter, r *http.Request, val string) error {
|
|
if len(val) < 2 || !strings.HasPrefix(val, `"`) || !strings.HasSuffix(val, `"`) {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", val, ErrInvalidKey)
|
|
}
|
|
|
|
key := val[1 : len(val)-1]
|
|
|
|
saved := p.read(key)
|
|
|
|
if saved != nil {
|
|
if r.Method != saved.method {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.Method, ErrMethodMismatch)
|
|
}
|
|
|
|
if r.URL.String() != saved.url {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "%s (%w)", r.URL.String(), ErrURLMismatch)
|
|
}
|
|
|
|
for _, h := range criticalHeaders {
|
|
if saved.requestHeader.Get(h) != r.Header.Get(h) {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "%s: %s (%w)", h, r.Header.Get(h), ErrHeaderMismatch)
|
|
}
|
|
}
|
|
|
|
h := sha256.New()
|
|
|
|
_, err := io.Copy(h, r.Body)
|
|
if err != nil {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "hash request body failed (%w)", err)
|
|
}
|
|
|
|
sha256 := h.Sum(nil)
|
|
if !bytes.Equal(sha256, saved.sha256) {
|
|
return jsrest.Errorf(jsrest.ErrBadRequest, "%s vs %s (%w)", sha256, saved.sha256, ErrBodyMismatch)
|
|
}
|
|
|
|
for key, vals := range saved.responseHeader {
|
|
w.Header().Set(key, vals[0])
|
|
}
|
|
|
|
w.WriteHeader(saved.statusCode)
|
|
_, _ = w.Write(saved.responseBody)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Store miss, proceed to normal execution with interception
|
|
err := p.lockKey(key)
|
|
if err != nil {
|
|
return jsrest.Errorf(jsrest.ErrConflict, "%s", key)
|
|
}
|
|
|
|
defer p.unlockKey(key)
|
|
|
|
requestHeader := http.Header{}
|
|
for _, h := range criticalHeaders {
|
|
requestHeader.Set(h, r.Header.Get(h))
|
|
}
|
|
|
|
bi := newBodyIntercept(r.Body)
|
|
r.Body = bi
|
|
|
|
rwi := newResponseWriterIntercept(w)
|
|
w = rwi
|
|
|
|
p.handler.ServeHTTP(w, r)
|
|
|
|
save := &savedResult{
|
|
key: key,
|
|
|
|
method: r.Method,
|
|
url: r.URL.String(),
|
|
requestHeader: requestHeader,
|
|
sha256: bi.sha256.Sum(nil),
|
|
|
|
statusCode: rwi.statusCode,
|
|
responseHeader: rwi.Header(),
|
|
responseBody: rwi.buf.Bytes(),
|
|
}
|
|
|
|
p.write(save)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Potency) lockKey(key string) error {
|
|
p.inProgressMu.Lock()
|
|
defer p.inProgressMu.Unlock()
|
|
|
|
if p.inProgress[key] {
|
|
return ErrConflict
|
|
}
|
|
|
|
p.inProgress[key] = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Potency) unlockKey(key string) {
|
|
p.inProgressMu.Lock()
|
|
defer p.inProgressMu.Unlock()
|
|
|
|
delete(p.inProgress, key)
|
|
}
|
|
|
|
func (p *Potency) read(key string) *savedResult {
|
|
p.cacheMu.RLock()
|
|
defer p.cacheMu.RUnlock()
|
|
|
|
return p.cache[key]
|
|
}
|
|
|
|
func (p *Potency) write(sr *savedResult) {
|
|
p.cacheMu.Lock()
|
|
defer p.cacheMu.Unlock()
|
|
|
|
sr.added = time.Now()
|
|
|
|
p.cache[sr.key] = sr
|
|
|
|
if p.cacheNewest != nil {
|
|
p.cacheNewest.newer = sr
|
|
}
|
|
|
|
p.cacheNewest = sr
|
|
|
|
if p.cacheOldest == nil {
|
|
p.cacheOldest = sr
|
|
}
|
|
|
|
p.removeExpired()
|
|
}
|
|
|
|
func (p *Potency) removeExpired() {
|
|
cutoff := time.Now().Add(-1 * p.lifetime)
|
|
|
|
for iter := p.cacheOldest; iter != nil && iter.added.Before(cutoff); iter = iter.newer {
|
|
delete(p.cache, iter.key)
|
|
p.cacheOldest = iter
|
|
}
|
|
}
|