mirror of
https://github.com/dalbodeule/hop-gate.git
synced 2025-12-12 06:40:11 +09:00
Merge pull request #18 from dalbodeule/copilot/fix-protobuf-length-prefix-framing
Fix DTLS protobuf codec for UDP datagram boundaries
This commit is contained in:
2
go.mod
2
go.mod
@@ -10,6 +10,7 @@ require (
|
||||
github.com/pion/dtls/v3 v3.0.7
|
||||
github.com/prometheus/client_golang v1.19.0
|
||||
golang.org/x/net v0.47.0
|
||||
google.golang.org/protobuf v1.36.10
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -40,5 +41,4 @@ require (
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@@ -32,8 +32,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
|
||||
github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
|
||||
@@ -51,7 +51,10 @@ func (jsonCodec) Decode(r io.Reader, env *Envelope) error {
|
||||
type protobufCodec struct{}
|
||||
|
||||
// Encode 는 Envelope 를 Protobuf Envelope 로 변환한 뒤, length-prefix 프레이밍으로 기록합니다.
|
||||
// DTLS는 UDP 기반이므로, length prefix와 protobuf 데이터를 단일 버퍼로 합쳐 하나의 Write로 전송합니다.
|
||||
// Encode encodes an Envelope as a length-prefixed protobuf message.
|
||||
// For DTLS (UDP-based), we combine the length prefix and protobuf data into a single buffer
|
||||
// and send it with a single Write call to preserve message boundaries.
|
||||
func (protobufCodec) Encode(w io.Writer, env *Envelope) error {
|
||||
pbEnv, err := toProtoEnvelope(env)
|
||||
if err != nil {
|
||||
@@ -83,58 +86,55 @@ func (protobufCodec) Encode(w io.Writer, env *Envelope) error {
|
||||
return fmt.Errorf("protobuf codec: empty marshaled envelope")
|
||||
}
|
||||
|
||||
var lenBuf [4]byte
|
||||
if len(data) > int(^uint32(0)) {
|
||||
return fmt.Errorf("protobuf codec: envelope too large: %d bytes", len(data))
|
||||
}
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(data)))
|
||||
|
||||
if _, err := w.Write(lenBuf[:]); err != nil {
|
||||
return fmt.Errorf("protobuf codec: write length prefix: %w", err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return fmt.Errorf("protobuf codec: write payload: %w", err)
|
||||
// DTLS 환경에서는 length prefix와 protobuf 데이터를 단일 버퍼로 합쳐서 하나의 Write로 전송
|
||||
// For DTLS, combine length prefix and protobuf data into a single buffer
|
||||
frame := make([]byte, 4+len(data))
|
||||
binary.BigEndian.PutUint32(frame[:4], uint32(len(data)))
|
||||
copy(frame[4:], data)
|
||||
|
||||
if _, err := w.Write(frame); err != nil {
|
||||
return fmt.Errorf("protobuf codec: write frame: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode 는 length-prefix 프레임에서 Protobuf Envelope 를 읽어들여
|
||||
// 내부 Envelope 구조체로 변환합니다.
|
||||
// DTLS는 UDP 기반이므로, 한 번의 Read로 전체 데이터그램을 읽습니다.
|
||||
// Decode reads a length-prefixed protobuf Envelope and converts it into the internal Envelope.
|
||||
// For DTLS (UDP-based), we read the entire datagram in a single Read call.
|
||||
func (protobufCodec) Decode(r io.Reader, env *Envelope) error {
|
||||
// IMPORTANT:
|
||||
// pion/dtls 는 복호화된 애플리케이션 데이터를 호출자가 제공한 버퍼에 한 번에 채웁니다.
|
||||
// 너무 작은 버퍼(예: 4바이트 len prefix)로 직접 Read 를 호출하면
|
||||
// "dtls: buffer is too small" (temporary) 에러가 발생할 수 있습니다.
|
||||
//
|
||||
// 이를 피하기 위해, DTLS 세션 위에서는 항상 충분히 큰 bufio.Reader 로 래핑한 뒤
|
||||
// io.ReadFull 을 사용합니다. 이렇게 하면 하위 DTLS Conn.Read 는
|
||||
// 내부 버퍼 크기(defaultDecoderBufferSize, 64KiB)만큼 읽고,
|
||||
// 그 위에서 length-prefix 를 안전하게 처리할 수 있습니다.
|
||||
br, ok := r.(*bufio.Reader)
|
||||
if !ok {
|
||||
br = bufio.NewReaderSize(r, defaultDecoderBufferSize)
|
||||
// DTLS는 메시지 경계가 보존되는 UDP 기반 프로토콜입니다.
|
||||
// 한 번의 Read로 전체 데이터그램(length prefix + protobuf data)을 읽어야 합니다.
|
||||
// DTLS is a UDP-based protocol that preserves message boundaries.
|
||||
// We must read the entire datagram (length prefix + protobuf data) in a single Read call.
|
||||
buf := make([]byte, maxProtoEnvelopeBytes+4)
|
||||
n, err := r.Read(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("protobuf codec: read frame: %w", err)
|
||||
}
|
||||
if n < 4 {
|
||||
return fmt.Errorf("protobuf codec: frame too short: %d bytes", n)
|
||||
}
|
||||
|
||||
var lenBuf [4]byte
|
||||
if _, err := io.ReadFull(br, lenBuf[:]); err != nil {
|
||||
return fmt.Errorf("protobuf codec: read length prefix: %w", err)
|
||||
}
|
||||
n := binary.BigEndian.Uint32(lenBuf[:])
|
||||
if n == 0 {
|
||||
// Extract and validate the length prefix
|
||||
length := binary.BigEndian.Uint32(buf[:4])
|
||||
if length == 0 {
|
||||
return fmt.Errorf("protobuf codec: zero-length envelope")
|
||||
}
|
||||
if n > maxProtoEnvelopeBytes {
|
||||
return fmt.Errorf("protobuf codec: envelope too large: %d bytes (max %d)", n, maxProtoEnvelopeBytes)
|
||||
if length > maxProtoEnvelopeBytes {
|
||||
return fmt.Errorf("protobuf codec: envelope too large: %d bytes (max %d)", length, maxProtoEnvelopeBytes)
|
||||
}
|
||||
|
||||
buf := make([]byte, int(n))
|
||||
if _, err := io.ReadFull(br, buf); err != nil {
|
||||
return fmt.Errorf("protobuf codec: read payload: %w", err)
|
||||
if int(length) != n-4 {
|
||||
return fmt.Errorf("protobuf codec: length mismatch: expected %d, got %d", length, n-4)
|
||||
}
|
||||
|
||||
var pbEnv protocolpb.Envelope
|
||||
if err := proto.Unmarshal(buf, &pbEnv); err != nil {
|
||||
if err := proto.Unmarshal(buf[4:n], &pbEnv); err != nil {
|
||||
return fmt.Errorf("protobuf codec: unmarshal envelope: %w", err)
|
||||
}
|
||||
|
||||
|
||||
221
internal/protocol/codec_test.go
Normal file
221
internal/protocol/codec_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockDatagramConn simulates a datagram-based connection (like DTLS over UDP)
|
||||
// where each Write sends a separate message and each Read receives a complete message.
|
||||
// This mock verifies the FIXED behavior where the codec properly handles message boundaries.
|
||||
type mockDatagramConn struct {
|
||||
messages [][]byte
|
||||
readIdx int
|
||||
}
|
||||
|
||||
func newMockDatagramConn() *mockDatagramConn {
|
||||
return &mockDatagramConn{
|
||||
messages: make([][]byte, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDatagramConn) Write(p []byte) (n int, err error) {
|
||||
// Simulate datagram behavior: each Write is a separate message
|
||||
msg := make([]byte, len(p))
|
||||
copy(msg, p)
|
||||
m.messages = append(m.messages, msg)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockDatagramConn) Read(p []byte) (n int, err error) {
|
||||
// Simulate datagram behavior: each Read returns a complete message
|
||||
if m.readIdx >= len(m.messages) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
msg := m.messages[m.readIdx]
|
||||
m.readIdx++
|
||||
if len(p) < len(msg) {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
copy(p, msg)
|
||||
return len(msg), nil
|
||||
}
|
||||
|
||||
// TestProtobufCodecDatagramBehavior tests that the protobuf codec works correctly
|
||||
// with datagram-based transports (like DTLS over UDP) where message boundaries are preserved.
|
||||
func TestProtobufCodecDatagramBehavior(t *testing.T) {
|
||||
codec := protobufCodec{}
|
||||
conn := newMockDatagramConn()
|
||||
|
||||
// Create a test envelope
|
||||
testEnv := &Envelope{
|
||||
Type: MessageTypeHTTP,
|
||||
HTTPRequest: &Request{
|
||||
RequestID: "test-req-123",
|
||||
ClientID: "client-1",
|
||||
ServiceName: "test-service",
|
||||
Method: "GET",
|
||||
URL: "/test/path",
|
||||
Header: map[string][]string{
|
||||
"User-Agent": {"test-client"},
|
||||
},
|
||||
Body: []byte("test body content"),
|
||||
},
|
||||
}
|
||||
|
||||
// Encode the envelope
|
||||
if err := codec.Encode(conn, testEnv); err != nil {
|
||||
t.Fatalf("Failed to encode envelope: %v", err)
|
||||
}
|
||||
|
||||
// Verify that exactly one message was written (length prefix + data in single Write)
|
||||
if len(conn.messages) != 1 {
|
||||
t.Fatalf("Expected 1 message to be written, got %d", len(conn.messages))
|
||||
}
|
||||
|
||||
// Verify the message structure: [4-byte length][protobuf data]
|
||||
msg := conn.messages[0]
|
||||
if len(msg) < 4 {
|
||||
t.Fatalf("Message too short: %d bytes", len(msg))
|
||||
}
|
||||
|
||||
// Decode the envelope
|
||||
var decodedEnv Envelope
|
||||
if err := codec.Decode(conn, &decodedEnv); err != nil {
|
||||
t.Fatalf("Failed to decode envelope: %v", err)
|
||||
}
|
||||
|
||||
// Verify the decoded envelope matches the original
|
||||
if decodedEnv.Type != testEnv.Type {
|
||||
t.Errorf("Type mismatch: got %v, want %v", decodedEnv.Type, testEnv.Type)
|
||||
}
|
||||
if decodedEnv.HTTPRequest == nil {
|
||||
t.Fatal("HTTPRequest is nil after decode")
|
||||
}
|
||||
if decodedEnv.HTTPRequest.RequestID != testEnv.HTTPRequest.RequestID {
|
||||
t.Errorf("RequestID mismatch: got %v, want %v", decodedEnv.HTTPRequest.RequestID, testEnv.HTTPRequest.RequestID)
|
||||
}
|
||||
if decodedEnv.HTTPRequest.Method != testEnv.HTTPRequest.Method {
|
||||
t.Errorf("Method mismatch: got %v, want %v", decodedEnv.HTTPRequest.Method, testEnv.HTTPRequest.Method)
|
||||
}
|
||||
if decodedEnv.HTTPRequest.URL != testEnv.HTTPRequest.URL {
|
||||
t.Errorf("URL mismatch: got %v, want %v", decodedEnv.HTTPRequest.URL, testEnv.HTTPRequest.URL)
|
||||
}
|
||||
if !bytes.Equal(decodedEnv.HTTPRequest.Body, testEnv.HTTPRequest.Body) {
|
||||
t.Errorf("Body mismatch: got %v, want %v", decodedEnv.HTTPRequest.Body, testEnv.HTTPRequest.Body)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtobufCodecStreamData tests encoding/decoding of StreamData messages
|
||||
func TestProtobufCodecStreamData(t *testing.T) {
|
||||
codec := protobufCodec{}
|
||||
conn := newMockDatagramConn()
|
||||
|
||||
// Create a StreamData envelope
|
||||
testEnv := &Envelope{
|
||||
Type: MessageTypeStreamData,
|
||||
StreamData: &StreamData{
|
||||
ID: StreamID("stream-123"),
|
||||
Seq: 42,
|
||||
Data: []byte("stream data payload"),
|
||||
},
|
||||
}
|
||||
|
||||
// Encode
|
||||
if err := codec.Encode(conn, testEnv); err != nil {
|
||||
t.Fatalf("Failed to encode StreamData: %v", err)
|
||||
}
|
||||
|
||||
// Verify single message
|
||||
if len(conn.messages) != 1 {
|
||||
t.Fatalf("Expected 1 message, got %d", len(conn.messages))
|
||||
}
|
||||
|
||||
// Decode
|
||||
var decodedEnv Envelope
|
||||
if err := codec.Decode(conn, &decodedEnv); err != nil {
|
||||
t.Fatalf("Failed to decode StreamData: %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
if decodedEnv.Type != MessageTypeStreamData {
|
||||
t.Errorf("Type mismatch: got %v, want %v", decodedEnv.Type, MessageTypeStreamData)
|
||||
}
|
||||
if decodedEnv.StreamData == nil {
|
||||
t.Fatal("StreamData is nil")
|
||||
}
|
||||
if decodedEnv.StreamData.ID != testEnv.StreamData.ID {
|
||||
t.Errorf("StreamID mismatch: got %v, want %v", decodedEnv.StreamData.ID, testEnv.StreamData.ID)
|
||||
}
|
||||
if decodedEnv.StreamData.Seq != testEnv.StreamData.Seq {
|
||||
t.Errorf("Seq mismatch: got %v, want %v", decodedEnv.StreamData.Seq, testEnv.StreamData.Seq)
|
||||
}
|
||||
if !bytes.Equal(decodedEnv.StreamData.Data, testEnv.StreamData.Data) {
|
||||
t.Errorf("Data mismatch: got %v, want %v", decodedEnv.StreamData.Data, testEnv.StreamData.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtobufCodecMultipleMessages tests encoding/decoding multiple messages
|
||||
func TestProtobufCodecMultipleMessages(t *testing.T) {
|
||||
codec := protobufCodec{}
|
||||
conn := newMockDatagramConn()
|
||||
|
||||
// Create multiple test envelopes
|
||||
envelopes := []*Envelope{
|
||||
{
|
||||
Type: MessageTypeStreamOpen,
|
||||
StreamOpen: &StreamOpen{
|
||||
ID: StreamID("stream-1"),
|
||||
Service: "test-service",
|
||||
TargetAddr: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: MessageTypeStreamData,
|
||||
StreamData: &StreamData{
|
||||
ID: StreamID("stream-1"),
|
||||
Seq: 1,
|
||||
Data: []byte("first chunk"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: MessageTypeStreamData,
|
||||
StreamData: &StreamData{
|
||||
ID: StreamID("stream-1"),
|
||||
Seq: 2,
|
||||
Data: []byte("second chunk"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: MessageTypeStreamClose,
|
||||
StreamClose: &StreamClose{
|
||||
ID: StreamID("stream-1"),
|
||||
Error: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Encode all messages
|
||||
for i, env := range envelopes {
|
||||
if err := codec.Encode(conn, env); err != nil {
|
||||
t.Fatalf("Failed to encode message %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that each encode produced exactly one message
|
||||
if len(conn.messages) != len(envelopes) {
|
||||
t.Fatalf("Expected %d messages, got %d", len(envelopes), len(conn.messages))
|
||||
}
|
||||
|
||||
// Decode and verify all messages
|
||||
for i := 0; i < len(envelopes); i++ {
|
||||
var decoded Envelope
|
||||
if err := codec.Decode(conn, &decoded); err != nil {
|
||||
t.Fatalf("Failed to decode message %d: %v", i, err)
|
||||
}
|
||||
if decoded.Type != envelopes[i].Type {
|
||||
t.Errorf("Message %d type mismatch: got %v, want %v", i, decoded.Type, envelopes[i].Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user