Fix concurrent request handling with stream multiplexing

- Add channel-based multiplexing to handle concurrent HTTP requests
- Implement background readLoop to dispatch responses to correct streams
- Remove mutex bottleneck that was serializing all requests
- Fixes "unexpected stream_data/stream_open for id" errors with concurrent requests

Co-authored-by: dalbodeule <11470513+dalbodeule@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-09 14:22:29 +00:00
parent 1292df33e5
commit ff38ef2828

View File

@@ -34,11 +34,31 @@ import (
// 기본값 "dev" 는 로컬 개발용입니다. // 기본값 "dev" 는 로컬 개발용입니다.
var version = "dev" var version = "dev"
// streamResponse collects the complete response for a single HTTP stream request
type streamResponse struct {
statusCode int
header map[string][]string
body bytes.Buffer
err error
}
// pendingRequest tracks a request waiting for its response
type pendingRequest struct {
streamID protocol.StreamID
respCh chan *protocol.Envelope
doneCh chan struct{}
}
type dtlsSessionWrapper struct { type dtlsSessionWrapper struct {
sess dtls.Session sess dtls.Session
bufferedReader *bufio.Reader bufferedReader *bufio.Reader
mu sync.Mutex codec protocol.WireCodec
nextStreamID uint64 logger logging.Logger
mu sync.Mutex
nextStreamID uint64
pending map[protocol.StreamID]*pendingRequest
readerDone chan struct{}
} }
func getEnvOrPanic(logger logging.Logger, key string) string { func getEnvOrPanic(logger logging.Logger, key string) string {
@@ -176,22 +196,115 @@ func parseExpectedIPsFromEnv(logger logging.Logger, envKey string) []net.IP {
// ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고, // ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고,
// 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko) // 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko)
// readLoop continuously reads from the DTLS session and dispatches incoming frames
// to the appropriate pending request based on stream ID
func (w *dtlsSessionWrapper) readLoop() {
defer close(w.readerDone)
for {
var env protocol.Envelope
if err := w.codec.Decode(w.bufferedReader, &env); err != nil {
if err == io.EOF {
w.logger.Info("dtls session closed", nil)
} else {
w.logger.Error("failed to decode envelope in read loop", logging.Fields{
"error": err.Error(),
})
}
// Notify all pending requests of the error
w.mu.Lock()
for _, pending := range w.pending {
close(pending.respCh)
}
w.pending = make(map[protocol.StreamID]*pendingRequest)
w.mu.Unlock()
return
}
// Determine the stream ID from the envelope
var streamID protocol.StreamID
switch env.Type {
case protocol.MessageTypeStreamOpen:
if env.StreamOpen != nil {
streamID = env.StreamOpen.ID
}
case protocol.MessageTypeStreamData:
if env.StreamData != nil {
streamID = env.StreamData.ID
}
case protocol.MessageTypeStreamClose:
if env.StreamClose != nil {
streamID = env.StreamClose.ID
}
default:
w.logger.Warn("received unexpected envelope type in read loop", logging.Fields{
"type": env.Type,
})
continue
}
if streamID == "" {
w.logger.Warn("received envelope with empty stream ID", logging.Fields{
"type": env.Type,
})
continue
}
// Find the pending request for this stream ID
w.mu.Lock()
pending := w.pending[streamID]
w.mu.Unlock()
if pending == nil {
w.logger.Warn("received envelope for unknown stream ID", logging.Fields{
"stream_id": streamID,
"type": env.Type,
})
continue
}
// Send the envelope to the waiting request
select {
case pending.respCh <- &env:
// Successfully delivered
case <-pending.doneCh:
// Request was cancelled or timed out
w.logger.Warn("pending request already closed", logging.Fields{
"stream_id": streamID,
})
}
}
}
// ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose // ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose
// frames and reconstructs the reverse stream into a protocol.Response. (en) // frames and reconstructs the reverse stream into a protocol.Response. (en)
// This method now supports concurrent requests by using a channel-based multiplexing approach.
func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) { func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) {
w.mu.Lock()
defer w.mu.Unlock()
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
codec := protocol.DefaultCodec // Generate a unique stream ID (needs mutex for nextStreamID)
w.mu.Lock()
// 세션 내에서 고유한 StreamID 를 생성합니다. (ko)
// Generate a unique StreamID for this HTTP request within the DTLS session. (en)
streamID := w.nextHTTPStreamID() streamID := w.nextHTTPStreamID()
// Create a pending request to receive responses
pending := &pendingRequest{
streamID: streamID,
respCh: make(chan *protocol.Envelope, 16), // Buffered to avoid blocking readLoop
doneCh: make(chan struct{}),
}
w.pending[streamID] = pending
w.mu.Unlock()
// Ensure cleanup on exit
defer func() {
w.mu.Lock()
delete(w.pending, streamID)
w.mu.Unlock()
close(pending.doneCh)
}()
log := logger.With(logging.Fields{ log := logger.With(logging.Fields{
"component": "http_to_dtls", "component": "http_to_dtls",
"request_id": string(streamID), "request_id": string(streamID),
@@ -233,7 +346,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
Header: hdr, Header: hdr,
}, },
} }
if err := codec.Encode(w.sess, openEnv); err != nil { if err := w.codec.Encode(w.sess, openEnv); err != nil {
log.Error("failed to encode stream_open envelope", logging.Fields{ log.Error("failed to encode stream_open envelope", logging.Fields{
"error": err.Error(), "error": err.Error(),
}) })
@@ -257,7 +370,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
Data: dataCopy, Data: dataCopy,
}, },
} }
if err2 := codec.Encode(w.sess, dataEnv); err2 != nil { if err2 := w.codec.Encode(w.sess, dataEnv); err2 != nil {
log.Error("failed to encode stream_data envelope", logging.Fields{ log.Error("failed to encode stream_data envelope", logging.Fields{
"error": err2.Error(), "error": err2.Error(),
}) })
@@ -283,7 +396,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
Error: "", Error: "",
}, },
} }
if err := codec.Encode(w.sess, closeReqEnv); err != nil { if err := w.codec.Encode(w.sess, closeReqEnv); err != nil {
log.Error("failed to encode request stream_close envelope", logging.Fields{ log.Error("failed to encode request stream_close envelope", logging.Fields{
"error": err.Error(), "error": err.Error(),
}) })
@@ -291,7 +404,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
} }
// 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko) // 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko)
// Receive reverse stream response (StreamOpen + StreamData* + StreamClose). (en) // Receive reverse stream response (StreamOpen + StreamData* + StreamClose) via the readLoop. (en)
var ( var (
resp protocol.Response resp protocol.Response
bodyBuf bytes.Buffer bodyBuf bytes.Buffer
@@ -303,79 +416,81 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
resp.Header = make(map[string][]string) resp.Header = make(map[string][]string)
for { for {
var env protocol.Envelope select {
if err := codec.Decode(w.bufferedReader, &env); err != nil { case <-ctx.Done():
log.Error("failed to decode stream response envelope", logging.Fields{ log.Error("context cancelled while waiting for response", logging.Fields{
"error": err.Error(), "error": ctx.Err().Error(),
}) })
return nil, err return nil, ctx.Err()
}
switch env.Type { case <-w.readerDone:
case protocol.MessageTypeStreamOpen: log.Error("dtls session closed while waiting for response", nil)
so := env.StreamOpen return nil, fmt.Errorf("dtls session closed")
if so == nil {
return nil, fmt.Errorf("stream_open response payload is nil") case env, ok := <-pending.respCh:
if !ok {
// Channel closed, session is dead
log.Error("response channel closed unexpectedly", nil)
return nil, fmt.Errorf("response channel closed")
} }
if so.ID != streamID {
return nil, fmt.Errorf("unexpected stream_open for id %q (expected %q)", so.ID, streamID) switch env.Type {
} case protocol.MessageTypeStreamOpen:
// 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko) so := env.StreamOpen
// Restore status code and headers (strip pseudo-headers). (en) if so == nil {
statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK)) return nil, fmt.Errorf("stream_open response payload is nil")
if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 {
statusCode = sc
}
for k, vs := range so.Header {
if k == protocol.HeaderKeyMethod ||
k == protocol.HeaderKeyURL ||
k == protocol.HeaderKeyHost ||
k == protocol.HeaderKeyStatus {
continue
} }
resp.Header[k] = append([]string(nil), vs...) // 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko)
} // Restore status code and headers (strip pseudo-headers). (en)
gotOpen = true statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK))
if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 {
case protocol.MessageTypeStreamData: statusCode = sc
sd := env.StreamData
if sd == nil {
return nil, fmt.Errorf("stream_data response payload is nil")
}
if sd.ID != streamID {
return nil, fmt.Errorf("unexpected stream_data for id %q (expected %q)", sd.ID, streamID)
}
if len(sd.Data) > 0 {
if _, err := bodyBuf.Write(sd.Data); err != nil {
return nil, fmt.Errorf("buffer stream_data response: %w", err)
} }
} for k, vs := range so.Header {
if k == protocol.HeaderKeyMethod ||
k == protocol.HeaderKeyURL ||
k == protocol.HeaderKeyHost ||
k == protocol.HeaderKeyStatus {
continue
}
resp.Header[k] = append([]string(nil), vs...)
}
gotOpen = true
case protocol.MessageTypeStreamClose: case protocol.MessageTypeStreamData:
sc := env.StreamClose sd := env.StreamData
if sc == nil { if sd == nil {
return nil, fmt.Errorf("stream_close response payload is nil") return nil, fmt.Errorf("stream_data response payload is nil")
} }
if sc.ID != streamID { if len(sd.Data) > 0 {
return nil, fmt.Errorf("unexpected stream_close for id %q (expected %q)", sc.ID, streamID) if _, err := bodyBuf.Write(sd.Data); err != nil {
} return nil, fmt.Errorf("buffer stream_data response: %w", err)
// 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko) }
// Stream finished: complete protocol.Response using collected headers/body. (en) }
resp.Status = statusCode
resp.Body = bodyBuf.Bytes()
resp.Error = sc.Error
log.Info("received stream http response over dtls", logging.Fields{ case protocol.MessageTypeStreamClose:
"status": resp.Status, sc := env.StreamClose
"error": resp.Error, if sc == nil {
}) return nil, fmt.Errorf("stream_close response payload is nil")
if !gotOpen { }
return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID) // 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko)
} // Stream finished: complete protocol.Response using collected headers/body. (en)
return &resp, nil resp.Status = statusCode
resp.Body = bodyBuf.Bytes()
resp.Error = sc.Error
default: log.Info("received stream http response over dtls", logging.Fields{
return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type) "status": resp.Status,
"error": resp.Error,
})
if !gotOpen {
return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID)
}
return &resp, nil
default:
return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type)
}
} }
} }
} }
@@ -507,7 +622,15 @@ func registerSessionForDomain(domain string, sess dtls.Session, logger logging.L
w := &dtlsSessionWrapper{ w := &dtlsSessionWrapper{
sess: sess, sess: sess,
bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()), bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()),
codec: protocol.DefaultCodec,
logger: logger.With(logging.Fields{"component": "dtls_session_wrapper", "domain": d}),
pending: make(map[protocol.StreamID]*pendingRequest),
readerDone: make(chan struct{}),
} }
// Start background reader goroutine to demultiplex incoming responses
go w.readLoop()
sessionsMu.Lock() sessionsMu.Lock()
sessionsByDomain[d] = w sessionsByDomain[d] = w
sessionsMu.Unlock() sessionsMu.Unlock()