diff --git a/modules/graceful/server.go b/modules/graceful/server.go index e812117dbd..0f547a6c30 100644 --- a/modules/graceful/server.go +++ b/modules/graceful/server.go @@ -85,7 +85,6 @@ func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) er listener = &proxyprotocol.Listener{ Listener: listener, ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, - AcceptUnknown: setting.ProxyProtocolAcceptUnknown, } } srv.listener = listener @@ -118,7 +117,6 @@ func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFun listener = &proxyprotocol.Listener{ Listener: listener, ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, - AcceptUnknown: setting.ProxyProtocolAcceptUnknown, } } @@ -130,7 +128,6 @@ func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFun listener = &proxyprotocol.Listener{ Listener: listener, ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout, - AcceptUnknown: setting.ProxyProtocolAcceptUnknown, } } diff --git a/modules/proxyprotocol/conn.go b/modules/proxyprotocol/conn.go index beac5de120..9434fa88a1 100644 --- a/modules/proxyprotocol/conn.go +++ b/modules/proxyprotocol/conn.go @@ -15,6 +15,7 @@ import ( "time" "forgejo.org/modules/log" + "forgejo.org/modules/setting" ) var ( @@ -46,6 +47,7 @@ func NewConn(conn net.Conn, timeout time.Duration) *Conn { bufReader: bufio.NewReader(conn), conn: conn, proxyHeaderTimeout: timeout, + acceptUnknown: setting.ProxyProtocolAcceptUnknown, } return pConn } @@ -456,7 +458,7 @@ func (p *Conn) readV1ProxyHeader() error { // Verify the type is known switch parts[1] { case "UNKNOWN": - if !p.acceptUnknown || len(parts) != 2 { + if !p.acceptUnknown { p.conn.Close() return &ErrBadHeader{[]byte(header)} } diff --git a/modules/proxyprotocol/conn_test.go b/modules/proxyprotocol/conn_test.go new file mode 100644 index 0000000000..39adc401a3 --- /dev/null +++ b/modules/proxyprotocol/conn_test.go @@ -0,0 +1,112 @@ +// Copyright 2026 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later + +package proxyprotocol_test + +import ( + "io" + "net" + "testing" + "time" + + "forgejo.org/modules/proxyprotocol" + "forgejo.org/modules/setting" + "forgejo.org/modules/test" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var v2Header = []byte{0xd, 0xa, 0xd, 0xa, 0x0, 0xd, 0xa, 0x51, 0x55, 0x49, 0x54, 0xa} + +func testConnection(t *testing.T, input []byte) *proxyprotocol.Conn { + local, remote := net.Pipe() + conn := proxyprotocol.NewConn(remote, 10*time.Second) + + go func(t *testing.T, conn net.Conn) { + _, err := conn.Write(input) + require.NoError(t, err) + + err = conn.Close() + require.NoError(t, err) + }(t, local) + + return conn +} + +func assertUwu(t *testing.T, conn *proxyprotocol.Conn) { + buf := make([]byte, 3) + read, err := conn.Read(buf) + require.NoError(t, err) + + assert.Equal(t, 3, read) + assert.Equal(t, []byte("uwu"), buf) +} + +func TestProxyProtocolParse(t *testing.T) { + // Basic v4/v6 TCP + ipv4Conn := testConnection(t, []byte("PROXY TCP4 7.3.3.1 1.3.3.7 14231 443\r\nuwu")) + + assertUwu(t, ipv4Conn) + assert.Equal(t, "7.3.3.1:14231", ipv4Conn.RemoteAddr().String()) + assert.Equal(t, "1.3.3.7:443", ipv4Conn.LocalAddr().String()) + + ipv6Conn := testConnection(t, []byte("PROXY TCP6 fe80::2 fe80::1 28512 443\r\nuwu")) + + assertUwu(t, ipv6Conn) + assert.Equal(t, "[fe80::2]:28512", ipv6Conn.RemoteAddr().String()) + assert.Equal(t, "[fe80::1]:443", ipv6Conn.LocalAddr().String()) + + ipv4Conn = testConnection(t, append(v2Header, 0x21, 0x11, 0x0, 0xc, 0x7, 0x3, 0x3, 0x1, 0x1, 0x3, 0x3, 0x7, 0xd9, 0xec, 0x1, 0xbb, 0x75, 0x77, 0x75)) + + assertUwu(t, ipv4Conn) + assert.Equal(t, "7.3.3.1:55788", ipv4Conn.RemoteAddr().String()) + assert.Equal(t, "1.3.3.7:443", ipv4Conn.LocalAddr().String()) + + ipv6Conn = testConnection(t, append(v2Header, 0x21, 0x21, 0x0, 0x24, 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xd9, 0xec, 0x1, 0xbb, 0x75, 0x77, 0x75)) + + assertUwu(t, ipv6Conn) + assert.Equal(t, "[fe80::2]:55788", ipv6Conn.RemoteAddr().String()) + assert.Equal(t, "[fe80::1]:443", ipv6Conn.LocalAddr().String()) + + // Basic unknown + unknownConn := testConnection(t, []byte("PROXY UNKNOWN\r\nuwu")) + _, err := unknownConn.Read([]byte{}) + require.Error(t, err) + + // Accept unknown protocol types + defer test.MockVariableValue(&setting.ProxyProtocolAcceptUnknown, true)() + + unknownConn = testConnection(t, []byte("PROXY UNKNOWN\r\nuwu")) + + assertUwu(t, unknownConn) + assert.Equal(t, "pipe", unknownConn.RemoteAddr().String()) + assert.Equal(t, "pipe", unknownConn.LocalAddr().String()) + + // Discard any unknown information between "UNKNOWN" and CRLF + + unknownConn = testConnection(t, []byte("PROXY UNKNOWN look, I'm hinding in an unknown protocol \\o/\r\nuwu")) + + assertUwu(t, unknownConn) + assert.Equal(t, "pipe", unknownConn.RemoteAddr().String()) + assert.Equal(t, "pipe", unknownConn.LocalAddr().String()) + + // Basic local + unknownConnV2 := testConnection(t, append(v2Header, 0x20, 0x0, 0x0, 0x0, 0x75, 0x77, 0x75)) + + assertUwu(t, unknownConnV2) + assert.Equal(t, "pipe", unknownConnV2.RemoteAddr().String()) + assert.Equal(t, "pipe", unknownConnV2.LocalAddr().String()) +} + +func TestProxyProtocolInvalidHeader(t *testing.T) { + // Short prefix + conn := testConnection(t, []byte("PROXY\r\n")) + _, err := conn.Read([]byte{}) + require.ErrorIs(t, err, io.EOF) + + // Wrong prefix + conn = testConnection(t, []byte("PROXYv1337\r\n")) + _, err = conn.Read([]byte{}) + require.ErrorContains(t, err, "Unexpected proxy header") +} diff --git a/modules/proxyprotocol/listener.go b/modules/proxyprotocol/listener.go index ec85c425d3..500181cd6a 100644 --- a/modules/proxyprotocol/listener.go +++ b/modules/proxyprotocol/listener.go @@ -18,7 +18,6 @@ import ( type Listener struct { Listener net.Listener ProxyHeaderTimeout time.Duration - AcceptUnknown bool // allow PROXY UNKNOWN } // Accept implements the Accept method in the Listener interface @@ -31,7 +30,6 @@ func (p *Listener) Accept() (net.Conn, error) { } newConn := NewConn(conn, p.ProxyHeaderTimeout) - newConn.acceptUnknown = p.AcceptUnknown return newConn, nil } diff --git a/modules/proxyprotocol/util.go b/modules/proxyprotocol/util.go index a280663b27..5fc89d80a9 100644 --- a/modules/proxyprotocol/util.go +++ b/modules/proxyprotocol/util.go @@ -5,7 +5,7 @@ package proxyprotocol import "io" -var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00', '\x00') +var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00') // WriteLocalHeader will write the ProxyProtocol Header for a local connection to the provided writer func WriteLocalHeader(w io.Writer) error {