diff --git a/cmd/server/main.go b/cmd/server/main.go index 117a492..18a25c3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -34,11 +34,31 @@ import ( // 기본값 "dev" 는 로컬 개발용입니다. var version = "dev" +// streamResponse collects the complete response for a single HTTP stream request +type streamResponse struct { + statusCode int + header map[string][]string + body bytes.Buffer + err error +} + +// pendingRequest tracks a request waiting for its response +type pendingRequest struct { + streamID protocol.StreamID + respCh chan *protocol.Envelope + doneCh chan struct{} +} + type dtlsSessionWrapper struct { sess dtls.Session bufferedReader *bufio.Reader - mu sync.Mutex - nextStreamID uint64 + codec protocol.WireCodec + logger logging.Logger + + mu sync.Mutex + nextStreamID uint64 + pending map[protocol.StreamID]*pendingRequest + readerDone chan struct{} } func getEnvOrPanic(logger logging.Logger, key string) string { @@ -176,21 +196,114 @@ func parseExpectedIPsFromEnv(logger logging.Logger, envKey string) []net.IP { // ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고, // 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko) +// readLoop continuously reads from the DTLS session and dispatches incoming frames +// to the appropriate pending request based on stream ID +func (w *dtlsSessionWrapper) readLoop() { + defer close(w.readerDone) + + for { + var env protocol.Envelope + if err := w.codec.Decode(w.bufferedReader, &env); err != nil { + if err == io.EOF { + w.logger.Info("dtls session closed", nil) + } else { + w.logger.Error("failed to decode envelope in read loop", logging.Fields{ + "error": err.Error(), + }) + } + // Notify all pending requests of the error + w.mu.Lock() + for _, pending := range w.pending { + close(pending.respCh) + } + w.pending = make(map[protocol.StreamID]*pendingRequest) + w.mu.Unlock() + return + } + + // Determine the stream ID from the envelope + var streamID protocol.StreamID + switch env.Type { + case protocol.MessageTypeStreamOpen: + if env.StreamOpen != nil { + streamID = env.StreamOpen.ID + } + case protocol.MessageTypeStreamData: + if env.StreamData != nil { + streamID = env.StreamData.ID + } + case protocol.MessageTypeStreamClose: + if env.StreamClose != nil { + streamID = env.StreamClose.ID + } + default: + w.logger.Warn("received unexpected envelope type in read loop", logging.Fields{ + "type": env.Type, + }) + continue + } + + if streamID == "" { + w.logger.Warn("received envelope with empty stream ID", logging.Fields{ + "type": env.Type, + }) + continue + } + + // Find the pending request for this stream ID + w.mu.Lock() + pending := w.pending[streamID] + w.mu.Unlock() + + if pending == nil { + w.logger.Warn("received envelope for unknown stream ID", logging.Fields{ + "stream_id": streamID, + "type": env.Type, + }) + continue + } + + // Send the envelope to the waiting request + select { + case pending.respCh <- &env: + // Successfully delivered + case <-pending.doneCh: + // Request was cancelled or timed out + w.logger.Warn("pending request already closed", logging.Fields{ + "stream_id": streamID, + }) + } + } +} + // ForwardHTTP forwards an HTTP request over the DTLS session using StreamOpen/StreamData/StreamClose // frames and reconstructs the reverse stream into a protocol.Response. (en) +// This method now supports concurrent requests by using a channel-based multiplexing approach. func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Logger, req *http.Request, serviceName string) (*protocol.Response, error) { - w.mu.Lock() - defer w.mu.Unlock() - if ctx == nil { ctx = context.Background() } - codec := protocol.DefaultCodec - - // 세션 내에서 고유한 StreamID 를 생성합니다. (ko) - // Generate a unique StreamID for this HTTP request within the DTLS session. (en) + // Generate a unique stream ID (needs mutex for nextStreamID) + w.mu.Lock() streamID := w.nextHTTPStreamID() + + // Create a pending request to receive responses + pending := &pendingRequest{ + streamID: streamID, + respCh: make(chan *protocol.Envelope, 16), // Buffered to avoid blocking readLoop + doneCh: make(chan struct{}), + } + w.pending[streamID] = pending + w.mu.Unlock() + + // Ensure cleanup on exit + defer func() { + w.mu.Lock() + delete(w.pending, streamID) + w.mu.Unlock() + close(pending.doneCh) + }() log := logger.With(logging.Fields{ "component": "http_to_dtls", @@ -233,7 +346,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log Header: hdr, }, } - if err := codec.Encode(w.sess, openEnv); err != nil { + if err := w.codec.Encode(w.sess, openEnv); err != nil { log.Error("failed to encode stream_open envelope", logging.Fields{ "error": err.Error(), }) @@ -257,7 +370,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log Data: dataCopy, }, } - if err2 := codec.Encode(w.sess, dataEnv); err2 != nil { + if err2 := w.codec.Encode(w.sess, dataEnv); err2 != nil { log.Error("failed to encode stream_data envelope", logging.Fields{ "error": err2.Error(), }) @@ -283,7 +396,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log Error: "", }, } - if err := codec.Encode(w.sess, closeReqEnv); err != nil { + if err := w.codec.Encode(w.sess, closeReqEnv); err != nil { log.Error("failed to encode request stream_close envelope", logging.Fields{ "error": err.Error(), }) @@ -291,7 +404,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log } // 클라이언트로부터 역방향 스트림 응답을 수신합니다. (ko) - // Receive reverse stream response (StreamOpen + StreamData* + StreamClose). (en) + // Receive reverse stream response (StreamOpen + StreamData* + StreamClose) via the readLoop. (en) var ( resp protocol.Response bodyBuf bytes.Buffer @@ -303,79 +416,81 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log resp.Header = make(map[string][]string) for { - var env protocol.Envelope - if err := codec.Decode(w.bufferedReader, &env); err != nil { - log.Error("failed to decode stream response envelope", logging.Fields{ - "error": err.Error(), + select { + case <-ctx.Done(): + log.Error("context cancelled while waiting for response", logging.Fields{ + "error": ctx.Err().Error(), }) - return nil, err - } + return nil, ctx.Err() - switch env.Type { - case protocol.MessageTypeStreamOpen: - so := env.StreamOpen - if so == nil { - return nil, fmt.Errorf("stream_open response payload is nil") + case <-w.readerDone: + log.Error("dtls session closed while waiting for response", nil) + return nil, fmt.Errorf("dtls session closed") + + case env, ok := <-pending.respCh: + if !ok { + // Channel closed, session is dead + log.Error("response channel closed unexpectedly", nil) + return nil, fmt.Errorf("response channel closed") } - if so.ID != streamID { - return nil, fmt.Errorf("unexpected stream_open for id %q (expected %q)", so.ID, streamID) - } - // 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko) - // Restore status code and headers (strip pseudo-headers). (en) - statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK)) - if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 { - statusCode = sc - } - for k, vs := range so.Header { - if k == protocol.HeaderKeyMethod || - k == protocol.HeaderKeyURL || - k == protocol.HeaderKeyHost || - k == protocol.HeaderKeyStatus { - continue + + switch env.Type { + case protocol.MessageTypeStreamOpen: + so := env.StreamOpen + if so == nil { + return nil, fmt.Errorf("stream_open response payload is nil") } - resp.Header[k] = append([]string(nil), vs...) - } - gotOpen = true - - case protocol.MessageTypeStreamData: - sd := env.StreamData - if sd == nil { - return nil, fmt.Errorf("stream_data response payload is nil") - } - if sd.ID != streamID { - return nil, fmt.Errorf("unexpected stream_data for id %q (expected %q)", sd.ID, streamID) - } - if len(sd.Data) > 0 { - if _, err := bodyBuf.Write(sd.Data); err != nil { - return nil, fmt.Errorf("buffer stream_data response: %w", err) + // 상태 코드 및 헤더 복원 (pseudo-header 제거). (ko) + // Restore status code and headers (strip pseudo-headers). (en) + statusStr := firstHeaderValue(so.Header, protocol.HeaderKeyStatus, strconv.Itoa(http.StatusOK)) + if sc, err := strconv.Atoi(statusStr); err == nil && sc > 0 { + statusCode = sc } - } + for k, vs := range so.Header { + if k == protocol.HeaderKeyMethod || + k == protocol.HeaderKeyURL || + k == protocol.HeaderKeyHost || + k == protocol.HeaderKeyStatus { + continue + } + resp.Header[k] = append([]string(nil), vs...) + } + gotOpen = true - case protocol.MessageTypeStreamClose: - sc := env.StreamClose - if sc == nil { - return nil, fmt.Errorf("stream_close response payload is nil") - } - if sc.ID != streamID { - return nil, fmt.Errorf("unexpected stream_close for id %q (expected %q)", sc.ID, streamID) - } - // 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko) - // Stream finished: complete protocol.Response using collected headers/body. (en) - resp.Status = statusCode - resp.Body = bodyBuf.Bytes() - resp.Error = sc.Error + case protocol.MessageTypeStreamData: + sd := env.StreamData + if sd == nil { + return nil, fmt.Errorf("stream_data response payload is nil") + } + if len(sd.Data) > 0 { + if _, err := bodyBuf.Write(sd.Data); err != nil { + return nil, fmt.Errorf("buffer stream_data response: %w", err) + } + } - log.Info("received stream http response over dtls", logging.Fields{ - "status": resp.Status, - "error": resp.Error, - }) - if !gotOpen { - return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID) - } - return &resp, nil + case protocol.MessageTypeStreamClose: + sc := env.StreamClose + if sc == nil { + return nil, fmt.Errorf("stream_close response payload is nil") + } + // 스트림 종료: 지금까지 수신한 헤더/바디로 protocol.Response 를 완성합니다. (ko) + // Stream finished: complete protocol.Response using collected headers/body. (en) + resp.Status = statusCode + resp.Body = bodyBuf.Bytes() + resp.Error = sc.Error - default: - return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type) + log.Info("received stream http response over dtls", logging.Fields{ + "status": resp.Status, + "error": resp.Error, + }) + if !gotOpen { + return nil, fmt.Errorf("received stream_close without prior stream_open for stream %q", streamID) + } + return &resp, nil + + default: + return nil, fmt.Errorf("unexpected envelope type %q in stream response", env.Type) + } } } } @@ -507,7 +622,15 @@ func registerSessionForDomain(domain string, sess dtls.Session, logger logging.L w := &dtlsSessionWrapper{ sess: sess, bufferedReader: bufio.NewReaderSize(sess, protocol.GetDTLSReadBufferSize()), + codec: protocol.DefaultCodec, + logger: logger.With(logging.Fields{"component": "dtls_session_wrapper", "domain": d}), + pending: make(map[protocol.StreamID]*pendingRequest), + readerDone: make(chan struct{}), } + + // Start background reader goroutine to demultiplex incoming responses + go w.readLoop() + sessionsMu.Lock() sessionsByDomain[d] = w sessionsMu.Unlock()