From 4278afcbc4e07742e3901e412463f8d6efeba2ce Mon Sep 17 00:00:00 2001 From: Ian Gulliver Date: Tue, 25 Apr 2023 15:18:35 -0700 Subject: [PATCH] Working two-way proxy --- go.sum | 0 proxy.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 go.sum diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/proxy.go b/proxy.go index c51d237..7e0df0b 100644 --- a/proxy.go +++ b/proxy.go @@ -1,17 +1,28 @@ package proxy -import "net" +import ( + "errors" + "net" + "sync" + "testing" +) type Proxy struct { + t *testing.T backend *net.TCPAddr listener *net.TCPListener + + conns map[*net.TCPConn]bool + mu sync.Mutex } -func NewProxy(backend *net.TCPAddr) (*Proxy, error) { +func NewProxy(t *testing.T, backend *net.TCPAddr) (*Proxy, error) { var err error p := &Proxy{ + t: t, backend: backend, + conns: map[*net.TCPConn]bool{}, } p.listener, err = net.ListenTCP("tcp", nil) @@ -21,20 +32,88 @@ func NewProxy(backend *net.TCPAddr) (*Proxy, error) { go p.accept() + t.Logf("* -> %s -> [proxy] -> * -> %s listening...", p.listener.Addr(), p.backend) + return p, nil } +func (p *Proxy) Addr() *net.TCPAddr { + return p.listener.Addr().(*net.TCPAddr) +} + func (p *Proxy) Close() { + p.t.Logf("* -> %s -> [proxy] -> * -> %s closing...", p.listener.Addr(), p.backend) + p.listener.Close() + + p.mu.Lock() + defer p.mu.Unlock() + + for conn := range p.conns { + conn.Close() + } + + p.conns = map[*net.TCPConn]bool{} + + p.t.Logf("* -> %s -> [proxy] -> * -> %s closed", p.listener.Addr(), p.backend) } func (p *Proxy) accept() { for { - conn, err := p.listener.Accept() + frontConn, err := p.listener.AcceptTCP() if err != nil { return } - conn.Close() + go p.dial(frontConn) + } +} + +func (p *Proxy) dial(frontConn *net.TCPConn) { + p.t.Logf("%s -> %s -> [proxy] -> * -> %s dialing...", frontConn.RemoteAddr(), frontConn.LocalAddr(), p.backend) + + backConn, err := net.DialTCP(p.backend.Network(), nil, p.backend) + if err != nil { + p.t.Logf("%s -> %s -> [proxy] -> * -> %s dialing failed: %s", frontConn.RemoteAddr(), frontConn.LocalAddr(), p.backend, err) + frontConn.Close() + return + } + + p.t.Logf("%s -> %s -> [proxy] -> %s -> %s connected", frontConn.RemoteAddr(), frontConn.LocalAddr(), backConn.LocalAddr(), backConn.RemoteAddr()) + + p.addConns(frontConn, backConn) + + go p.copy(frontConn, backConn) + go p.copy(backConn, frontConn) +} + +func (p *Proxy) copy(src, dest *net.TCPConn) { + numBytes, err := dest.ReadFrom(src) + + if err == nil || errors.Is(err, net.ErrClosed) { + p.t.Logf("%s -> %s -> [proxy] -> %s -> %s closed after %d bytes", src.RemoteAddr(), src.LocalAddr(), dest.LocalAddr(), dest.RemoteAddr(), numBytes) + } else { + p.t.Logf("%s -> %s -> [proxy] -> %s -> %s closed after %d bytes: %s", src.RemoteAddr(), src.LocalAddr(), dest.LocalAddr(), dest.RemoteAddr(), numBytes, err) + } + + dest.Close() + p.delConns(src) +} + +func (p *Proxy) addConns(conns ...*net.TCPConn) { + p.mu.Lock() + defer p.mu.Unlock() + + for _, conn := range conns { + p.conns[conn] = true + } +} + +func (p *Proxy) delConns(conns ...*net.TCPConn) { + p.mu.Lock() + defer p.mu.Unlock() + + for _, conn := range conns { + delete(p.conns, conn) } }