Files
proxy/proxy.go

170 lines
3.6 KiB
Go
Raw Permalink Normal View History

2023-04-25 14:26:46 -07:00
package proxy
2023-04-25 15:18:35 -07:00
import (
"errors"
2023-06-07 23:24:32 -07:00
"fmt"
2023-04-25 15:18:35 -07:00
"net"
"sync"
"testing"
2023-06-07 23:24:32 -07:00
"github.com/stretchr/testify/require"
2023-04-25 15:18:35 -07:00
)
2023-04-25 14:26:46 -07:00
// Proxy is a TCP forwarder for tests: it accepts connections on a local
// listener and relays each one to a backend address.
type Proxy struct {
	// t receives all log output produced by the proxy.
	t *testing.T

	// backend is the address new connections are dialed to; guarded by mu.
	backend *net.TCPAddr
	// listener accepts incoming (front-end) connections.
	listener *net.TCPListener

	// conns tracks open front- and back-end connections; guarded by mu.
	conns map[*net.TCPConn]bool
	// refuse, when true, makes the accept loop close new connections
	// immediately; guarded by mu.
	refuse bool
	mu     sync.Mutex
}
2023-06-07 23:24:32 -07:00
func NewProxy(t *testing.T, backend *net.TCPAddr) *Proxy {
2023-04-25 14:26:46 -07:00
var err error
p := &Proxy{
2023-04-25 15:18:35 -07:00
t: t,
2023-04-25 14:26:46 -07:00
backend: backend,
2023-04-25 15:18:35 -07:00
conns: map[*net.TCPConn]bool{},
2023-04-25 14:26:46 -07:00
}
p.listener, err = net.ListenTCP("tcp", nil)
2023-06-07 23:24:32 -07:00
require.NoError(t, err)
2023-04-25 14:26:46 -07:00
go p.accept()
2023-04-25 15:18:35 -07:00
t.Logf("* -> %s -> [proxy] -> * -> %s listening...", p.listener.Addr(), p.backend)
2023-06-07 23:24:32 -07:00
return p
2023-04-25 14:26:46 -07:00
}
2023-04-25 15:18:35 -07:00
// Addr returns the TCP address the proxy is listening on.
func (p *Proxy) Addr() *net.TCPAddr {
	local := p.listener.Addr()
	return local.(*net.TCPAddr)
}
2023-06-07 23:24:32 -07:00
// HTTP returns an http:// base URL (with trailing slash) for the proxy's
// listening address.
func (p *Proxy) HTTP() string {
	addr := p.Addr()
	return fmt.Sprintf("http://%s/", addr)
}
// HTTPS returns an https:// base URL (with trailing slash) for the proxy's
// listening address.
func (p *Proxy) HTTPS() string {
	addr := p.Addr()
	return fmt.Sprintf("https://%s/", addr)
}
2023-04-25 15:40:44 -07:00
func (p *Proxy) CloseAllConns() {
2023-04-25 15:18:35 -07:00
p.mu.Lock()
defer p.mu.Unlock()
2023-05-28 18:57:59 -07:00
p.t.Logf("* -> %s -> [proxy] -> * -> %s closing all connections", p.listener.Addr(), p.backend)
2023-04-25 15:18:35 -07:00
for conn := range p.conns {
conn.Close()
}
p.conns = map[*net.TCPConn]bool{}
2023-04-25 15:40:44 -07:00
}
func (p *Proxy) SetRefuse(refuse bool) {
p.mu.Lock()
defer p.mu.Unlock()
2023-05-28 18:57:59 -07:00
if refuse {
p.t.Logf("* -> %s -> [proxy] -> * -> %s refusing new connections", p.listener.Addr(), p.backend)
} else {
p.t.Logf("* -> %s -> [proxy] -> * -> %s accepting new connections", p.listener.Addr(), p.backend)
}
2023-04-25 15:40:44 -07:00
p.refuse = true
}
2023-05-28 18:57:59 -07:00
// SetBackend changes the address that subsequently accepted connections are
// dialed to. Connections that are already established keep their backend.
func (p *Proxy) SetBackend(backend *net.TCPAddr) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.backend = backend
	p.t.Logf("* -> %s -> [proxy] -> * -> %s switched to new backend", p.listener.Addr(), backend)
}
2023-04-25 15:40:44 -07:00
func (p *Proxy) Close() {
2023-05-28 18:57:59 -07:00
p.mu.Lock()
2023-04-25 15:40:44 -07:00
p.t.Logf("* -> %s -> [proxy] -> * -> %s closing...", p.listener.Addr(), p.backend)
2023-05-28 18:57:59 -07:00
p.mu.Unlock()
2023-04-25 15:40:44 -07:00
p.listener.Close()
p.CloseAllConns()
2023-04-25 15:18:35 -07:00
2023-05-28 18:57:59 -07:00
p.mu.Lock()
2023-04-25 15:18:35 -07:00
p.t.Logf("* -> %s -> [proxy] -> * -> %s closed", p.listener.Addr(), p.backend)
2023-05-28 18:57:59 -07:00
p.mu.Unlock()
2023-04-25 14:26:46 -07:00
}
func (p *Proxy) accept() {
for {
2023-04-25 15:18:35 -07:00
frontConn, err := p.listener.AcceptTCP()
2023-04-25 14:26:46 -07:00
if err != nil {
return
}
2023-04-25 15:40:44 -07:00
p.mu.Lock()
if p.refuse {
2023-05-28 18:57:59 -07:00
p.t.Logf("%s -> %s -> [proxy] -> * -> %s refused", frontConn.RemoteAddr(), frontConn.LocalAddr(), p.backend)
2023-04-25 15:40:44 -07:00
frontConn.Close()
} else {
go p.dial(frontConn)
}
p.mu.Unlock()
2023-04-25 15:18:35 -07:00
}
}
// dial connects an accepted front-end connection to the current backend.
// On success it registers both connections and starts one copy goroutine per
// direction. If the backend cannot be reached, frontConn is closed.
func (p *Proxy) dial(frontConn *net.TCPConn) {
	// Snapshot the backend under the mutex. dial runs on its own goroutine
	// and SetBackend may replace p.backend concurrently, so reading the field
	// without the lock is a data race. A single snapshot also keeps the logs
	// consistent with the address actually dialed.
	p.mu.Lock()
	backend := p.backend
	p.mu.Unlock()

	p.t.Logf("%s -> %s -> [proxy] -> * -> %s dialing...", frontConn.RemoteAddr(), frontConn.LocalAddr(), backend)
	backConn, err := net.DialTCP(backend.Network(), nil, backend)
	if err != nil {
		p.t.Logf("%s -> %s -> [proxy] -> * -> %s dialing failed: %s", frontConn.RemoteAddr(), frontConn.LocalAddr(), backend, err)
		frontConn.Close()
		return
	}
	p.t.Logf("%s -> %s -> [proxy] -> %s -> %s connected", frontConn.RemoteAddr(), frontConn.LocalAddr(), backConn.LocalAddr(), backConn.RemoteAddr())

	// Track both ends so CloseAllConns can tear them down.
	p.addConns(frontConn, backConn)
	go p.copy(frontConn, backConn)
	go p.copy(backConn, frontConn)
}
// copy relays bytes from src into dest until the transfer ends. It logs how
// many bytes were moved, closes dest, and stops tracking src.
func (p *Proxy) copy(src, dest *net.TCPConn) {
	transferred, err := dest.ReadFrom(src)
	switch {
	case err != nil && !errors.Is(err, net.ErrClosed):
		p.t.Logf("%s -> %s -> [proxy] -> %s -> %s closed after %d bytes: %s", src.RemoteAddr(), src.LocalAddr(), dest.LocalAddr(), dest.RemoteAddr(), transferred, err)
	default:
		// A nil error or net.ErrClosed is logged without an error suffix.
		p.t.Logf("%s -> %s -> [proxy] -> %s -> %s closed after %d bytes", src.RemoteAddr(), src.LocalAddr(), dest.LocalAddr(), dest.RemoteAddr(), transferred)
	}

	// When two copies run in opposite directions (as dial starts them),
	// closing dest makes the other copy's read end, and that copy in turn
	// closes this src.
	dest.Close()
	p.delConns(src)
}
// addConns records the given connections as open so CloseAllConns can close
// them.
func (p *Proxy) addConns(conns ...*net.TCPConn) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for i := range conns {
		p.conns[conns[i]] = true
	}
}
func (p *Proxy) delConns(conns ...*net.TCPConn) {
p.mu.Lock()
defer p.mu.Unlock()
for _, conn := range conns {
delete(p.conns, conn)
2023-04-25 14:26:46 -07:00
}
}