fix: remove trailing null byte for local connection (#11295)

Fixes #633

I have written end-to-end tests against HAProxy in https://code.forgejo.org/forgejo/end-to-end/pulls/1578 and also written unit tests.

Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/11295
Reviewed-by: Gusted <gusted@noreply.codeberg.org>
Co-authored-by: famfo <famfo@famfo.xyz>
Co-committed-by: famfo <famfo@famfo.xyz>
This commit is contained in:
famfo 2026-02-16 05:55:50 +01:00 committed by Gusted
parent cf17b5fad9
commit 9767cebc42
5 changed files with 116 additions and 7 deletions

View file

@ -85,7 +85,6 @@ func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) er
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}
srv.listener = listener
@ -118,7 +117,6 @@ func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFun
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}
@ -130,7 +128,6 @@ func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFun
listener = &proxyprotocol.Listener{
Listener: listener,
ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
}

View file

@ -15,6 +15,7 @@ import (
"time"
"forgejo.org/modules/log"
"forgejo.org/modules/setting"
)
var (
@ -46,6 +47,7 @@ func NewConn(conn net.Conn, timeout time.Duration) *Conn {
bufReader: bufio.NewReader(conn),
conn: conn,
proxyHeaderTimeout: timeout,
acceptUnknown: setting.ProxyProtocolAcceptUnknown,
}
return pConn
}
@ -456,7 +458,7 @@ func (p *Conn) readV1ProxyHeader() error {
// Verify the type is known
switch parts[1] {
case "UNKNOWN":
if !p.acceptUnknown || len(parts) != 2 {
if !p.acceptUnknown {
p.conn.Close()
return &ErrBadHeader{[]byte(header)}
}

View file

@ -0,0 +1,112 @@
// Copyright 2026 The Forgejo Authors. All rights reserved.
// SPDX-License-Identifier: GPL-3.0-or-later
package proxyprotocol_test
import (
"io"
"net"
"testing"
"time"
"forgejo.org/modules/proxyprotocol"
"forgejo.org/modules/setting"
"forgejo.org/modules/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var v2Header = []byte{0xd, 0xa, 0xd, 0xa, 0x0, 0xd, 0xa, 0x51, 0x55, 0x49, 0x54, 0xa}
func testConnection(t *testing.T, input []byte) *proxyprotocol.Conn {
local, remote := net.Pipe()
conn := proxyprotocol.NewConn(remote, 10*time.Second)
go func(t *testing.T, conn net.Conn) {
_, err := conn.Write(input)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}(t, local)
return conn
}
func assertUwu(t *testing.T, conn *proxyprotocol.Conn) {
buf := make([]byte, 3)
read, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 3, read)
assert.Equal(t, []byte("uwu"), buf)
}
func TestProxyProtocolParse(t *testing.T) {
// Basic v4/v6 TCP
ipv4Conn := testConnection(t, []byte("PROXY TCP4 7.3.3.1 1.3.3.7 14231 443\r\nuwu"))
assertUwu(t, ipv4Conn)
assert.Equal(t, "7.3.3.1:14231", ipv4Conn.RemoteAddr().String())
assert.Equal(t, "1.3.3.7:443", ipv4Conn.LocalAddr().String())
ipv6Conn := testConnection(t, []byte("PROXY TCP6 fe80::2 fe80::1 28512 443\r\nuwu"))
assertUwu(t, ipv6Conn)
assert.Equal(t, "[fe80::2]:28512", ipv6Conn.RemoteAddr().String())
assert.Equal(t, "[fe80::1]:443", ipv6Conn.LocalAddr().String())
ipv4Conn = testConnection(t, append(v2Header, 0x21, 0x11, 0x0, 0xc, 0x7, 0x3, 0x3, 0x1, 0x1, 0x3, 0x3, 0x7, 0xd9, 0xec, 0x1, 0xbb, 0x75, 0x77, 0x75))
assertUwu(t, ipv4Conn)
assert.Equal(t, "7.3.3.1:55788", ipv4Conn.RemoteAddr().String())
assert.Equal(t, "1.3.3.7:443", ipv4Conn.LocalAddr().String())
ipv6Conn = testConnection(t, append(v2Header, 0x21, 0x21, 0x0, 0x24, 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xd9, 0xec, 0x1, 0xbb, 0x75, 0x77, 0x75))
assertUwu(t, ipv6Conn)
assert.Equal(t, "[fe80::2]:55788", ipv6Conn.RemoteAddr().String())
assert.Equal(t, "[fe80::1]:443", ipv6Conn.LocalAddr().String())
// Basic unknown
unknownConn := testConnection(t, []byte("PROXY UNKNOWN\r\nuwu"))
_, err := unknownConn.Read([]byte{})
require.Error(t, err)
// Accept unknown protocol types
defer test.MockVariableValue(&setting.ProxyProtocolAcceptUnknown, true)()
unknownConn = testConnection(t, []byte("PROXY UNKNOWN\r\nuwu"))
assertUwu(t, unknownConn)
assert.Equal(t, "pipe", unknownConn.RemoteAddr().String())
assert.Equal(t, "pipe", unknownConn.LocalAddr().String())
// Discard any unknown information between "UNKNOWN" and CRLF
unknownConn = testConnection(t, []byte("PROXY UNKNOWN look, I'm hinding in an unknown protocol \\o/\r\nuwu"))
assertUwu(t, unknownConn)
assert.Equal(t, "pipe", unknownConn.RemoteAddr().String())
assert.Equal(t, "pipe", unknownConn.LocalAddr().String())
// Basic local
unknownConnV2 := testConnection(t, append(v2Header, 0x20, 0x0, 0x0, 0x0, 0x75, 0x77, 0x75))
assertUwu(t, unknownConnV2)
assert.Equal(t, "pipe", unknownConnV2.RemoteAddr().String())
assert.Equal(t, "pipe", unknownConnV2.LocalAddr().String())
}
func TestProxyProtocolInvalidHeader(t *testing.T) {
// Short prefix
conn := testConnection(t, []byte("PROXY\r\n"))
_, err := conn.Read([]byte{})
require.ErrorIs(t, err, io.EOF)
// Wrong prefix
conn = testConnection(t, []byte("PROXYv1337\r\n"))
_, err = conn.Read([]byte{})
require.ErrorContains(t, err, "Unexpected proxy header")
}

View file

@ -18,7 +18,6 @@ import (
type Listener struct {
Listener net.Listener
ProxyHeaderTimeout time.Duration
AcceptUnknown bool // allow PROXY UNKNOWN
}
// Accept implements the Accept method in the Listener interface
@ -31,7 +30,6 @@ func (p *Listener) Accept() (net.Conn, error) {
}
newConn := NewConn(conn, p.ProxyHeaderTimeout)
newConn.acceptUnknown = p.AcceptUnknown
return newConn, nil
}

View file

@ -5,7 +5,7 @@ package proxyprotocol
import "io"
var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00', '\x00')
var localHeader = append(v2Prefix, '\x20', '\x00', '\x00', '\x00')
// WriteLocalHeader will write the ProxyProtocol Header for a local connection to the provided writer
func WriteLocalHeader(w io.Writer) error {