mirror of
https://github.com/dalbodeule/hop-gate.git
synced 2025-12-12 14:50:09 +09:00
Fix concurrent request handling with stream multiplexing
- Add channel-based multiplexing to handle concurrent HTTP requests - Implement background readLoop to dispatch responses to correct streams - Remove mutex bottleneck that was serializing all requests - Fixes "unexpected stream_data/stream_open for id" errors with concurrent requests Co-authored-by: dalbodeule <11470513+dalbodeule@users.noreply.github.com>
This commit is contained in:
@@ -34,11 +34,31 @@ import (
|
|||||||
// 기본값 "dev" 는 로컬 개발용입니다.
|
// 기본값 "dev" 는 로컬 개발용입니다.
|
||||||
var version = "dev"
|
var version = "dev"
|
||||||
|
|
||||||
|
// streamResponse collects the complete response for a single HTTP stream request
|
||||||
|
type streamResponse struct {
|
||||||
|
statusCode int
|
||||||
|
header map[string][]string
|
||||||
|
body bytes.Buffer
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// pendingRequest tracks a request waiting for its response
|
||||||
|
type pendingRequest struct {
|
||||||
|
streamID protocol.StreamID
|
||||||
|
respCh chan *protocol.Envelope
|
||||||
|
doneCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
type dtlsSessionWrapper struct {
|
type dtlsSessionWrapper struct {
|
||||||
sess dtls.Session
|
sess dtls.Session
|
||||||
bufferedReader *bufio.Reader
|
bufferedReader *bufio.Reader
|
||||||
mu sync.Mutex
|
codec protocol.WireCodec
|
||||||
nextStreamID uint64
|
logger logging.Logger
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
nextStreamID uint64
|
||||||
|
pending map[protocol.StreamID]*pendingRequest
|
||||||
|
readerDone chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getEnvOrPanic(logger logging.Logger, key string) string {
|
func getEnvOrPanic(logger logging.Logger, key string) string {
|
||||||
@@ -176,22 +196,115 @@ func parseExpectedIPsFromEnv(logger logging.Logger, envKey string) []net.IP {
|
|||||||
|
|
||||||
// ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고,
|
// ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고,
|
||||||
// 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko)
|
// 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko)
|
||||||
|
// readLoop continuously reads from the DTLS session and dispatches incoming frames
|
||||||
|
// to the appropriate pending request based on stream ID
|
||||||
|
func (w *dtlsSessionWrapper) readLoop() {
|
||||||
|
defer close(w.readerDone)
|
||||||
|
|
||||||
|
for {
|
||||||
|
var env protocol.Envelope
|
||||||
|
if err := w.codec.Decode(w.bufferedReader, &env); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
w.logger.Info("dtls session closed", nil)
|
||||||
|
} else {
|
||||||
|
w.logger.Error("failed to decode envelope in read loop", logging.Fields{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// Notify all pending requests of the error
|
||||||
|
w.mu.Lock()
|
||||||
|
for _, pending := range w.pending {
|
||||||
|
close(pending.respCh)
|
||||||
|
}
|
||||||
|
w.pending = make(map[protocol.StreamID]*pendingRequest)
|
||||||
|
w.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the stream ID from the envelope
|
||||||
|
var streamID protocol.StreamID
|
||||||
|
switch env.Type {
|
||||||
|
case protocol.MessageTypeStreamOpen:
|
||||||
|
if env.StreamOpen != nil {
|
||||||
|
streamID = env.StreamOpen.ID
|
||||||
|
}
|
||||||
|
case protocol.MessageTypeStreamData:
|
||||||
|
if env.StreamData != nil {
|
||||||
|
streamID = env.StreamData.ID
|
||||||
|
}
|
||||||
|
case protocol.MessageTypeStreamClose:
|
||||||
|
if env.StreamClose != nil {
|
||||||
|
streamID = env.StreamClose.ID
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
w.logger.Warn("received unexpected envelope type in read loop", logging.Fields{
|
||||||
|
"type": env.Type,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if streamID == "" {
|
||||||
|
w.logger.Warn("received envelope with empty stream ID", logging.Fields{
|
||||||
|
"type": env.Type,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the pending request for this stream ID
|
||||||
|
w.mu.Lock()
|
||||||
|
pending := w.pending[streamID]
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
if pending == nil {
|
||||||
|
w.logger.Warn("received envelope for unknown stream ID", logging.Fields{
|
||||||
|
"stream_id": streamID,
|
||||||
|
"type": env.Type,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the envelope to the waiting request
|
||||||
|
select {
|
||||||
|
case pending.respCh <- &env:
|
||||||
|
// Successfully delivered
|
||||||
|
case <-pending.doneCh:
|
||||||
|
// Request was cancelled or timed out
|
||||||
|
w.logger.Warn("pending request already closed", logging.Fields{
|
||||||
|
"stream_id": streamID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose
|
// ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose
|
||||||
// frames and reconstructs the reverse stream into a protocol.Response. (en)
|
// frames and reconstructs the reverse stream into a protocol.Response. (en)
|
||||||
|
// This method now supports concurrent requests by using a channel-based multiplexing approach.
|
||||||
func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) {
|
func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) {
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
|
|
||||||
codec := protocol.DefaultCodec
|
// Generate a unique stream ID (needs mutex for nextStreamID)
|
||||||
|
w.mu.Lock()
|
||||||
// 세션 내에서 고유한 StreamID 를 생성합니다. (ko)
|
|
||||||
// Generate a unique StreamID for this HTTP request within the DTLS session. (en)
|
|
||||||
streamID := w.nextHTTPStreamID()
|
streamID := w.nextHTTPStreamID()
|
||||||
|
|
||||||
|
// Create a pending request to receive responses
|
||||||
|
pending := &pendingRequest{
|
||||||
|
streamID: streamID,
|
||||||
|
respCh: make(chan *protocol.Envelope, 16), // Buffered to avoid blocking readLoop
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
w.pending[streamID] = pending
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
// Ensure cleanup on exit
|
||||||
|
defer func() {
|
||||||
|
w.mu.Lock()
|
||||||
|
delete(w.pending, streamID)
|
||||||
|
w.mu.Unlock()
|
||||||
|
close(pending.doneCh)
|
||||||
|
}()
|
||||||
|
|
||||||
log := logger.With(logging.Fields{
|
log := logger.With(logging.Fields{
|
||||||
"component": "http_to_dtls",
|
"component": "http_to_dtls",
|
||||||
"request_id": string(streamID),
|
"request_id": string(streamID),
|
||||||
@@ -233,7 +346,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
|||||||
Header: hdr,
|
Header: hdr,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err := codec.Encode(w.sess, openEnv); err != nil {
|
if err := w.codec.Encode(w.sess, openEnv); err != nil {
|
||||||
log.Error("failed to encode stream_open envelope", logging.Fields{
|
log.Error("failed to encode stream_open envelope", logging.Fields{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
@@ -257,7 +370,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
|||||||
Data: dataCopy,
|
Data: dataCopy,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err2 := codec.Encode(w.sess, dataEnv); err2 != nil {
|
if err2 := w.codec.Encode(w.sess, dataEnv); err2 != nil {
|
||||||
log.Error("failed to encode stream_data envelope", logging.Fields{
|
log.Error("failed to encode stream_data envelope", logging.Fields{
|
||||||
"error": err2.Error(),
|
"error": err2.Error(),
|
||||||
})
|
})
|
||||||
@@ -283,7 +396,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
|||||||
Error: "",
|
Error: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err := codec.Encode(w.sess, closeReqEnv); err != nil {
|
if err := w.codec.Encode(w.sess, closeReqEnv); err != nil {
|
||||||
log.Error("failed to encode request stream_close envelope", logging.Fields{
|
log.Error("failed to encode request stream_close envelope", logging.Fields{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
@@ -291,7 +404,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko)
|
// 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko)
|
||||||
// Receive reverse stream response (StreamOpen + StreamData* + StreamClose). (en)
|
// Receive reverse stream response (StreamOpen + StreamData* + StreamClose) via the readLoop. (en)
|
||||||
var (
|
var (
|
||||||
resp protocol.Response
|
resp protocol.Response
|
||||||
bodyBuf bytes.Buffer
|
bodyBuf bytes.Buffer
|
||||||
@@ -303,79 +416,81 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log
|
|||||||
resp.Header = make(map[string][]string)
|
resp.Header = make(map[string][]string)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var env protocol.Envelope
|
select {
|
||||||
if err := codec.Decode(w.bufferedReader, &env); err != nil {
|
case <-ctx.Done():
|
||||||
log.Error("failed to decode stream response envelope", logging.Fields{
|
log.Error("context cancelled while waiting for response", logging.Fields{
|
||||||
"error": err.Error(),
|
"error": ctx.Err().Error(),
|
||||||
})
|
})
|
||||||
return nil, err
|
return nil, ctx.Err()
|
||||||
}
|
|
||||||
|
|
||||||
switch env.Type {
|
case <-w.readerDone:
|
||||||
case protocol.MessageTypeStreamOpen:
|
log.Error("dtls session closed while waiting for response", nil)
|
||||||
so := env.StreamOpen
|
return nil, fmt.Errorf("dtls session closed")
|
||||||
if so == nil {
|
|
||||||
return nil, fmt.Errorf("stream_open response payload is nil")
|
case env, ok := <-pending.respCh:
|
||||||
|
if !ok {
|
||||||
|
// Channel closed, session is dead
|
||||||
|
log.Error("response channel closed unexpectedly", nil)
|
||||||
|
return nil, fmt.Errorf("response channel closed")
|
||||||
}
|
}
|
||||||
if so.ID != streamID {
|
|
||||||
return nil, fmt.Errorf("unexpected stream_open for id %q (expected %q)", so.ID, streamID)
|
switch env.Type {
|
||||||
}
|
case protocol.MessageTypeStreamOpen:
|
||||||
// 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko)
|
so := env.StreamOpen
|
||||||
// Restore status code and headers (strip pseudo-headers). (en)
|
if so == nil {
|
||||||
statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK))
|
return nil, fmt.Errorf("stream_open response payload is nil")
|
||||||
if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 {
|
|
||||||
statusCode = sc
|
|
||||||
}
|
|
||||||
for k, vs := range so.Header {
|
|
||||||
if k == protocol.HeaderKeyMethod ||
|
|
||||||
k == protocol.HeaderKeyURL ||
|
|
||||||
k == protocol.HeaderKeyHost ||
|
|
||||||
k == protocol.HeaderKeyStatus {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
resp.Header[k] = append([]string(nil), vs...)
|
// 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko)
|
||||||
}
|
// Restore status code and headers (strip pseudo-headers). (en)
|
||||||
gotOpen = true
|
statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK))
|
||||||
|
if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 {
|
||||||
case protocol.MessageTypeStreamData:
|
statusCode = sc
|
||||||
sd := env.StreamData
|
|
||||||
if sd == nil {
|
|
||||||
return nil, fmt.Errorf("stream_data response payload is nil")
|
|
||||||
}
|
|
||||||
if sd.ID != streamID {
|
|
||||||
return nil, fmt.Errorf("unexpected stream_data for id %q (expected %q)", sd.ID, streamID)
|
|
||||||
}
|
|
||||||
if len(sd.Data) > 0 {
|
|
||||||
if _, err := bodyBuf.Write(sd.Data); err != nil {
|
|
||||||
return nil, fmt.Errorf("buffer stream_data response: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
for k, vs := range so.Header {
|
||||||
|
if k == protocol.HeaderKeyMethod ||
|
||||||
|
k == protocol.HeaderKeyURL ||
|
||||||
|
k == protocol.HeaderKeyHost ||
|
||||||
|
k == protocol.HeaderKeyStatus {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp.Header[k] = append([]string(nil), vs...)
|
||||||
|
}
|
||||||
|
gotOpen = true
|
||||||
|
|
||||||
case protocol.MessageTypeStreamClose:
|
case protocol.MessageTypeStreamData:
|
||||||
sc := env.StreamClose
|
sd := env.StreamData
|
||||||
if sc == nil {
|
if sd == nil {
|
||||||
return nil, fmt.Errorf("stream_close response payload is nil")
|
return nil, fmt.Errorf("stream_data response payload is nil")
|
||||||
}
|
}
|
||||||
if sc.ID != streamID {
|
if len(sd.Data) > 0 {
|
||||||
return nil, fmt.Errorf("unexpected stream_close for id %q (expected %q)", sc.ID, streamID)
|
if _, err := bodyBuf.Write(sd.Data); err != nil {
|
||||||
}
|
return nil, fmt.Errorf("buffer stream_data response: %w", err)
|
||||||
// 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko)
|
}
|
||||||
// Stream finished: complete protocol.Response using collected headers/body. (en)
|
}
|
||||||
resp.Status = statusCode
|
|
||||||
resp.Body = bodyBuf.Bytes()
|
|
||||||
resp.Error = sc.Error
|
|
||||||
|
|
||||||
log.Info("received stream http response over dtls", logging.Fields{
|
case protocol.MessageTypeStreamClose:
|
||||||
"status": resp.Status,
|
sc := env.StreamClose
|
||||||
"error": resp.Error,
|
if sc == nil {
|
||||||
})
|
return nil, fmt.Errorf("stream_close response payload is nil")
|
||||||
if !gotOpen {
|
}
|
||||||
return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID)
|
// 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko)
|
||||||
}
|
// Stream finished: complete protocol.Response using collected headers/body. (en)
|
||||||
return &resp, nil
|
resp.Status = statusCode
|
||||||
|
resp.Body = bodyBuf.Bytes()
|
||||||
|
resp.Error = sc.Error
|
||||||
|
|
||||||
default:
|
log.Info("received stream http response over dtls", logging.Fields{
|
||||||
return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type)
|
"status": resp.Status,
|
||||||
|
"error": resp.Error,
|
||||||
|
})
|
||||||
|
if !gotOpen {
|
||||||
|
return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID)
|
||||||
|
}
|
||||||
|
return &resp, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -507,7 +622,15 @@ func registerSessionForDomain(domain string, sess dtls.Session, logger logging.L
|
|||||||
w := &dtlsSessionWrapper{
|
w := &dtlsSessionWrapper{
|
||||||
sess: sess,
|
sess: sess,
|
||||||
bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()),
|
bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()),
|
||||||
|
codec: protocol.DefaultCodec,
|
||||||
|
logger: logger.With(logging.Fields{"component": "dtls_session_wrapper", "domain": d}),
|
||||||
|
pending: make(map[protocol.StreamID]*pendingRequest),
|
||||||
|
readerDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start background reader goroutine to demultiplex incoming responses
|
||||||
|
go w.readLoop()
|
||||||
|
|
||||||
sessionsMu.Lock()
|
sessionsMu.Lock()
|
||||||
sessionsByDomain[d] = w
|
sessionsByDomain[d] = w
|
||||||
sessionsMu.Unlock()
|
sessionsMu.Unlock()
|
||||||
|
|||||||
Reference in New Issue
Block a user