mirror of
https://github.com/dalbodeule/hop-gate.git
synced 2025-12-12 06:40:11 +09:00
Fix concurrent request handling with stream multiplexing
- Add channel-based multiplexing to handle concurrent HTTP requests - Implement background readLoop to dispatch responses to correct streams - Remove mutex bottleneck that was serializing all requests - Fixes "unexpected stream_data/stream_open for id" errors with concurrent requests Co-authored-by: dalbodeule <11470513+dalbodeule@users.noreply.github.com>
This commit is contained in:
@@ -34,11 +34,31 @@ import (
|
||||
// 기본값 "dev" 는 로컬 개발용입니다.
|
||||
var version = "dev"
|
||||
|
||||
// streamResponse collects the complete response for a single HTTP stream request
|
||||
type streamResponse struct {
|
||||
statusCode int
|
||||
header map[string][]string
|
||||
body bytes.Buffer
|
||||
err error
|
||||
}
|
||||
|
||||
// pendingRequest tracks a request waiting for its response
|
||||
type pendingRequest struct {
|
||||
streamID protocol.StreamID
|
||||
respCh chan *protocol.Envelope
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
type dtlsSessionWrapper struct {
|
||||
sess dtls.Session
|
||||
bufferedReader *bufio.Reader
|
||||
codec protocol.WireCodec
|
||||
logger logging.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
nextStreamID uint64
|
||||
pending map[protocol.StreamID]*pendingRequest
|
||||
readerDone chan struct{}
|
||||
}
|
||||
|
||||
func getEnvOrPanic(logger logging.Logger, key string) string {
|
||||
@@ -176,22 +196,115 @@ func parseExpectedIPsFromEnv(logger logging.Logger, envKey string) []net.IP {
|
||||
|
||||
// ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고,
|
||||
// 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko)
|
||||
// readLoop continuously reads from the DTLS session and dispatches incoming frames
|
||||
// to the appropriate pending request based on stream ID
|
||||
func (w *dtlsSessionWrapper) readLoop() {
|
||||
defer close(w.readerDone)
|
||||
|
||||
for {
|
||||
var env protocol.Envelope
|
||||
if err := w.codec.Decode(w.bufferedReader, &env); err != nil {
|
||||
if err == io.EOF {
|
||||
w.logger.Info("dtls session closed", nil)
|
||||
} else {
|
||||
w.logger.Error("failed to decode envelope in read loop", logging.Fields{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
// Notify all pending requests of the error
|
||||
w.mu.Lock()
|
||||
for _, pending := range w.pending {
|
||||
close(pending.respCh)
|
||||
}
|
||||
w.pending = make(map[protocol.StreamID]*pendingRequest)
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the stream ID from the envelope
|
||||
var streamID protocol.StreamID
|
||||
switch env.Type {
|
||||
case protocol.MessageTypeStreamOpen:
|
||||
if env.StreamOpen != nil {
|
||||
streamID = env.StreamOpen.ID
|
||||
}
|
||||
case protocol.MessageTypeStreamData:
|
||||
if env.StreamData != nil {
|
||||
streamID = env.StreamData.ID
|
||||
}
|
||||
case protocol.MessageTypeStreamClose:
|
||||
if env.StreamClose != nil {
|
||||
streamID = env.StreamClose.ID
|
||||
}
|
||||
default:
|
||||
w.logger.Warn("received unexpected envelope type in read loop", logging.Fields{
|
||||
"type": env.Type,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if streamID == "" {
|
||||
w.logger.Warn("received envelope with empty stream ID", logging.Fields{
|
||||
"type": env.Type,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the pending request for this stream ID
|
||||
w.mu.Lock()
|
||||
pending := w.pending[streamID]
|
||||
w.mu.Unlock()
|
||||
|
||||
if pending == nil {
|
||||
w.logger.Warn("received envelope for unknown stream ID", logging.Fields{
|
||||
"stream_id": streamID,
|
||||
"type": env.Type,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Send the envelope to the waiting request
|
||||
select {
|
||||
case pending.respCh <- &env:
|
||||
// Successfully delivered
|
||||
case <-pending.doneCh:
|
||||
// Request was cancelled or timed out
|
||||
w.logger.Warn("pending request already closed", logging.Fields{
|
||||
"stream_id": streamID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose
|
||||
// frames and reconstructs the reverse stream into a protocol.Response. (en)
|
||||
// This method now supports concurrent requests by using a channel-based multiplexing approach.
|
||||
func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
codec := protocol.DefaultCodec
|
||||
|
||||
// 세션 내에서 고유한 StreamID 를 생성합니다. (ko)
|
||||
// Generate a unique StreamID for this HTTP request within the DTLS session. (en)
|
||||
// Generate a unique stream ID (needs mutex for nextStreamID)
|
||||
w.mu.Lock()
|
||||
streamID := w.nextHTTPStreamID()
|
||||
|
||||
// Create a pending request to receive responses
|
||||
pending := &pendingRequest{
|
||||
streamID: streamID,
|
||||
respCh: make(chan *protocol.Envelope, 16), // Buffered to avoid blocking readLoop
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
w.pending[streamID] = pending
|
||||
w.mu.Unlock()
|
||||
|
||||
// Ensure cleanup on exit
|
||||
defer func() {
|
||||
w.mu.Lock()
|
||||
delete(w.pending, streamID)
|
||||
w.mu.Unlock()
|
||||
close(pending.doneCh)
|
||||
}()
|
||||
|
||||
log := logger.With(logging.Fields{
|
||||
"component": "http_to_dtls",
|
||||
"request_id": string(streamID),
|
||||
@@ -233,7 +346,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
Header: hdr,
|
||||
},
|
||||
}
|
||||
if err := codec.Encode(w.sess, openEnv); err != nil {
|
||||
if err := w.codec.Encode(w.sess, openEnv); err != nil {
|
||||
log.Error("failed to encode stream_open envelope", logging.Fields{
|
||||
"error": err.Error(),
|
||||
})
|
||||
@@ -257,7 +370,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
Data: dataCopy,
|
||||
},
|
||||
}
|
||||
if err2 := codec.Encode(w.sess, dataEnv); err2 != nil {
|
||||
if err2 := w.codec.Encode(w.sess, dataEnv); err2 != nil {
|
||||
log.Error("failed to encode stream_data envelope", logging.Fields{
|
||||
"error": err2.Error(),
|
||||
})
|
||||
@@ -283,7 +396,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
Error: "",
|
||||
},
|
||||
}
|
||||
if err := codec.Encode(w.sess, closeReqEnv); err != nil {
|
||||
if err := w.codec.Encode(w.sess, closeReqEnv); err != nil {
|
||||
log.Error("failed to encode request stream_close envelope", logging.Fields{
|
||||
"error": err.Error(),
|
||||
})
|
||||
@@ -291,7 +404,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
}
|
||||
|
||||
// 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko)
|
||||
// Receive reverse stream response (StreamOpen + StreamData* + StreamClose). (en)
|
||||
// Receive reverse stream response (StreamOpen + StreamData* + StreamClose) via the readLoop. (en)
|
||||
var (
|
||||
resp protocol.Response
|
||||
bodyBuf bytes.Buffer
|
||||
@@ -303,12 +416,22 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
resp.Header = make(map[string][]string)
|
||||
|
||||
for {
|
||||
var env protocol.Envelope
|
||||
if err := codec.Decode(w.bufferedReader, &env); err != nil {
|
||||
log.Error("failed to decode stream response envelope", logging.Fields{
|
||||
"error": err.Error(),
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error("context cancelled while waiting for response", logging.Fields{
|
||||
"error": ctx.Err().Error(),
|
||||
})
|
||||
return nil, err
|
||||
return nil, ctx.Err()
|
||||
|
||||
case <-w.readerDone:
|
||||
log.Error("dtls session closed while waiting for response", nil)
|
||||
return nil, fmt.Errorf("dtls session closed")
|
||||
|
||||
case env, ok := <-pending.respCh:
|
||||
if !ok {
|
||||
// Channel closed, session is dead
|
||||
log.Error("response channel closed unexpectedly", nil)
|
||||
return nil, fmt.Errorf("response channel closed")
|
||||
}
|
||||
|
||||
switch env.Type {
|
||||
@@ -317,9 +440,6 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
if so == nil {
|
||||
return nil, fmt.Errorf("stream_open response payload is nil")
|
||||
}
|
||||
if so.ID != streamID {
|
||||
return nil, fmt.Errorf("unexpected stream_open for id %q (expected %q)", so.ID, streamID)
|
||||
}
|
||||
// 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko)
|
||||
// Restore status code and headers (strip pseudo-headers). (en)
|
||||
statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK))
|
||||
@@ -342,9 +462,6 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
if sd == nil {
|
||||
return nil, fmt.Errorf("stream_data response payload is nil")
|
||||
}
|
||||
if sd.ID != streamID {
|
||||
return nil, fmt.Errorf("unexpected stream_data for id %q (expected %q)", sd.ID, streamID)
|
||||
}
|
||||
if len(sd.Data) > 0 {
|
||||
if _, err := bodyBuf.Write(sd.Data); err != nil {
|
||||
return nil, fmt.Errorf("buffer stream_data response: %w", err)
|
||||
@@ -356,9 +473,6 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
if sc == nil {
|
||||
return nil, fmt.Errorf("stream_close response payload is nil")
|
||||
}
|
||||
if sc.ID != streamID {
|
||||
return nil, fmt.Errorf("unexpected stream_close for id %q (expected %q)", sc.ID, streamID)
|
||||
}
|
||||
// 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko)
|
||||
// Stream finished: complete protocol.Response using collected headers/body. (en)
|
||||
resp.Status = statusCode
|
||||
@@ -379,6 +493,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextHTTPStreamID 는 DTLS 세션 내 HTTP 요청에 사용할 고유 StreamID 를 생성합니다. (ko)
|
||||
// nextHTTPStreamID generates a unique StreamID for HTTP requests on this DTLS session. (en)
|
||||
@@ -507,7 +622,15 @@ func registerSessionForDomain(domain string, sess dtls.Session, logger logging.L
|
||||
w := &dtlsSessionWrapper{
|
||||
sess: sess,
|
||||
bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()),
|
||||
codec: protocol.DefaultCodec,
|
||||
logger: logger.With(logging.Fields{"component": "dtls_session_wrapper", "domain": d}),
|
||||
pending: make(map[protocol.StreamID]*pendingRequest),
|
||||
readerDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start background reader goroutine to demultiplex incoming responses
|
||||
go w.readLoop()
|
||||
|
||||
sessionsMu.Lock()
|
||||
sessionsByDomain[d] = w
|
||||
sessionsMu.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user