diff --git a/cmd/server/main.go b/cmd/server/main.go index 241a99b..0a89d85 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -12,6 +12,7 @@ import ( "net/http" "os" "path/filepath" + "sort" "strconv" "strings" "sync" @@ -41,6 +42,72 @@ type pendingRequest struct { doneCh chan struct{} } +// streamSender 는 특정 스트림에 대해 전송한 StreamData 프레임의 payload 를 +// 시퀀스 번호별로 보관하여, peer 로부터의 StreamAck 를 기반으로 선택적 재전송을 +// 수행하기 위한 송신 측 ARQ 상태를 나타냅니다. (ko) +// streamSender keeps outstanding StreamData payloads per sequence number so that +// they can be selectively retransmitted based on StreamAck from the peer. (en) +type streamSender struct { + mu sync.Mutex + outstanding map[uint64][]byte +} + +func newStreamSender() *streamSender { + return &streamSender{ + outstanding: make(map[uint64][]byte), + } +} + +func (s *streamSender) register(seq uint64, data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.outstanding == nil { + s.outstanding = make(map[uint64][]byte) + } + buf := make([]byte, len(data)) + copy(buf, data) + s.outstanding[seq] = buf +} + +// handleAck 는 주어진 StreamAck 를 적용하여 AckSeq 이하의 프레임을 정리하고, +// LostSeqs 중 아직 outstanding 에 남아 있는 시퀀스의 payload 를 복사하여 +// 재전송 대상 목록으로 반환합니다. (ko) +// handleAck applies the given StreamAck, removes frames up to AckSeq, and +// returns copies of payloads for LostSeqs that are still outstanding so that +// they can be retransmitted. (en) +func (s *streamSender) handleAck(ack *protocol.StreamAck) map[uint64][]byte { + s.mu.Lock() + defer s.mu.Unlock() + + if s.outstanding == nil { + return nil + } + + // 연속 수신 완료 구간(seq <= AckSeq)은 outstanding 에서 제거합니다. + for seq := range s.outstanding { + if seq <= ack.AckSeq { + delete(s.outstanding, seq) + } + } + + // LostSeqs 가 비어 있으면 재전송할 것이 없습니다. + if len(ack.LostSeqs) == 0 { + return nil + } + + // LostSeqs 중 아직 outstanding 에 남아 있는 것만 재전송 대상으로 선택합니다. + lost := make(map[uint64][]byte, len(ack.LostSeqs)) + for _, seq := range ack.LostSeqs { + if data, ok := s.outstanding[seq]; ok { + buf := make([]byte, len(data)) + copy(buf, data) + lost[seq] = buf + } + } + return lost +} + type dtlsSessionWrapper struct { sess dtls.Session bufferedReader *bufio.Reader @@ -51,6 +118,48 @@ type dtlsSessionWrapper struct { nextStreamID uint64 pending map[protocol.StreamID]*pendingRequest readerDone chan struct{} + + // streamSenders 는 서버 → 클라이언트 방향 HTTP 요청 바디 전송에 대한 + // 송신 측 ARQ 상태를 보관합니다. (ko) + // streamSenders keeps ARQ sender state for HTTP request bodies sent + // from server to client. (en) + streamSenders map[protocol.StreamID]*streamSender +} + +// registerStreamSender 는 주어진 스트림 ID 에 대한 송신 측 ARQ 상태를 등록합니다. (ko) +// registerStreamSender registers the sender-side ARQ state for a given stream ID. (en) +func (w *dtlsSessionWrapper) registerStreamSender(id protocol.StreamID, sender *streamSender) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.streamSenders == nil { + w.streamSenders = make(map[protocol.StreamID]*streamSender) + } + w.streamSenders[id] = sender +} + +// unregisterStreamSender 는 더 이상 사용하지 않는 스트림 ID 에 대한 송신 측 ARQ 상태를 제거합니다. (ko) +// unregisterStreamSender removes the sender-side ARQ state for a stream ID that is no longer used. (en) +func (w *dtlsSessionWrapper) unregisterStreamSender(id protocol.StreamID) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.streamSenders == nil { + return + } + delete(w.streamSenders, id) +} + +// getStreamSender 는 주어진 스트림 ID 에 대한 송신 측 ARQ 상태를 반환합니다. (ko) +// getStreamSender returns the sender-side ARQ state for the given stream ID, if any. (en) +func (w *dtlsSessionWrapper) getStreamSender(id protocol.StreamID) *streamSender { + w.mu.Lock() + defer w.mu.Unlock() + + if w.streamSenders == nil { + return nil + } + return w.streamSenders[id] } func getEnvOrPanic(logger logging.Logger, key string) string { @@ -189,7 +298,8 @@ func parseExpectedIPsFromEnv(logger logging.Logger, envKey string) []net.IP { // ForwardHTTP 는 HTTP 요청을 DTLS 세션 위의 StreamOpen/StreamData/StreamClose 프레임으로 전송하고, // 역방향 스트림 응답을 수신해 protocol.Response 로 반환합니다. (ko) // readLoop continuously reads from the DTLS session and dispatches incoming frames -// to the appropriate pending request based on stream ID +// to the appropriate pending request based on stream ID. It also handles +// application-level ARQ (StreamAck) for request bodies sent from server to client. (en) func (w *dtlsSessionWrapper) readLoop() { defer close(w.readerDone) @@ -203,8 +313,8 @@ func (w *dtlsSessionWrapper) readLoop() { "error": err.Error(), }) } - // Notify all pending requests of the error by closing their response channels - // The doneCh will be closed by each ForwardHTTP's defer + // Notify all pending requests of the error by closing their response channels. + // The doneCh will be closed by each ForwardHTTP's defer. w.mu.Lock() for _, pending := range w.pending { close(pending.respCh) @@ -214,7 +324,53 @@ func (w *dtlsSessionWrapper) readLoop() { return } - // Determine the stream ID from the envelope + // 1) StreamAck 처리: 서버 → 클라이언트 방향 요청 바디 전송에 대한 ARQ. (ko) + // 1) Handle StreamAck: application-level ARQ for request bodies + // sent from server to client. (en) + if env.Type == protocol.MessageTypeStreamAck { + sa := env.StreamAck + if sa == nil { + w.logger.Warn("received stream_ack envelope with nil payload", logging.Fields{}) + continue + } + streamID := sa.ID + sender := w.getStreamSender(streamID) + if sender == nil { + w.logger.Warn("received stream_ack for unknown stream ID", logging.Fields{ + "stream_id": streamID, + }) + continue + } + lost := sender.handleAck(sa) + for seq, data := range lost { + retryEnv := protocol.Envelope{ + Type: protocol.MessageTypeStreamData, + StreamData: &protocol.StreamData{ + ID: streamID, + Seq: seq, + Data: data, + }, + } + if err := w.codec.Encode(w.sess, &retryEnv); err != nil { + w.logger.Error("failed to retransmit stream_data after stream_ack", logging.Fields{ + "stream_id": streamID, + "seq": seq, + "error": err.Error(), + }) + // 세션 쓰기 오류가 발생하면 루프를 종료하여 상위에서 세션 종료를 유도합니다. (ko) + // On write error, stop the loop so that the caller can tear down the session. (en) + return + } + } + // StreamAck 는 애플리케이션 페이로드를 포함하지 않으므로 pending 에 전달하지 않습니다. (ko) + // StreamAck carries no application payload, so it is not forwarded to pending requests. (en) + continue + } + + // 2) StreamOpen / StreamData / StreamClose 에 대해 stream ID 를 산출하고, + // 해당 pending 요청으로 전달합니다. (ko) + // 2) For StreamOpen / StreamData / StreamClose, determine the stream ID + // and forward to the corresponding pending request. (en) var streamID protocol.StreamID switch env.Type { case protocol.MessageTypeStreamOpen: @@ -286,7 +442,7 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log // Generate a unique stream ID (needs mutex for nextStreamID) w.mu.Lock() streamID := w.nextHTTPStreamID() - + // Channel buffer size for response frames to avoid blocking readLoop. // A typical HTTP response has: 1 StreamOpen + N StreamData + 1 StreamClose frames. // With 4KB chunks, even large responses stay within this buffer. @@ -301,12 +457,18 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log w.pending[streamID] = pending w.mu.Unlock() + // 서버 → 클라이언트 방향 요청 바디 전송에 대한 송신 측 ARQ 상태를 준비합니다. (ko) + // Prepare ARQ sender state for the request body sent from server to client. (en) + sender := newStreamSender() + w.registerStreamSender(streamID, sender) + // Ensure cleanup on exit defer func() { w.mu.Lock() delete(w.pending, streamID) w.mu.Unlock() close(pending.doneCh) + w.unregisterStreamSender(streamID) }() log := logger.With(logging.Fields{ @@ -366,6 +528,10 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log n, err := req.Body.Read(buf) if n > 0 { dataCopy := append([]byte(nil), buf[:n]...) + // 송신 측 ARQ: Seq 별 payload 를 기록해 두었다가, 클라이언트의 StreamAck 를 기반으로 재전송합니다. (ko) + // Sender-side ARQ: record payload per Seq so it can be retransmitted based on StreamAck from the client. (en) + sender.register(seq, dataCopy) + dataEnv := &protocol.Envelope{ Type: protocol.MessageTypeStreamData, StreamData: &protocol.StreamData{ @@ -414,7 +580,14 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log bodyBuf bytes.Buffer gotOpen bool statusCode = http.StatusOK + + // 응답 바디(클라이언트 → 서버)에 대한 수신 측 ARQ 상태입니다. (ko) + // ARQ receiver state for the response body (client → server). (en) + expectedSeq uint64 + received = make(map[uint64][]byte) + lost = make(map[uint64]struct{}) ) + const maxLostReport = 32 resp.RequestID = string(streamID) resp.Header = make(map[string][]string) @@ -466,10 +639,94 @@ func (w *dtlsSessionWrapper) ForwardHTTP(ctx context.Context, logger logging.Log if sd == nil { return nil, fmt.Errorf("stream_data response payload is nil") } - if len(sd.Data) > 0 { - if _, err := bodyBuf.Write(sd.Data); err != nil { - return nil, fmt.Errorf("buffer stream_data response: %w", err) + + // 수신 측 ARQ: Seq 에 따라 분기하고, 연속 구간을 bodyBuf 에 순서대로 기록합니다. (ko) + // Receiver-side ARQ: handle Seq and append contiguous data to bodyBuf in order. (en) + switch { + case sd.Seq == expectedSeq: + if len(sd.Data) > 0 { + if _, err := bodyBuf.Write(sd.Data); err != nil { + return nil, fmt.Errorf("buffer stream_data response: %w", err) + } } + expectedSeq++ + for { + data, ok := received[expectedSeq] + if !ok { + break + } + if len(data) > 0 { + if _, err := bodyBuf.Write(data); err != nil { + return nil, fmt.Errorf("buffer reordered stream_data response: %w", err) + } + } + delete(received, expectedSeq) + delete(lost, expectedSeq) + expectedSeq++ + } + + // AckSeq 이전 구간의 lost 항목 정리 + for seq := range lost { + if seq < expectedSeq { + delete(lost, seq) + } + } + + case sd.Seq > expectedSeq: + // 앞선 일부 Seq 들이 누락된 상태: 현재 프레임을 버퍼링하고 missing seq 들을 lost 에 추가. (ko) + // Missing earlier Seq: buffer this frame and mark missing seqs as lost. (en) + if len(sd.Data) > 0 { + bufCopy := make([]byte, len(sd.Data)) + copy(bufCopy, sd.Data) + received[sd.Seq] = bufCopy + } + for seq := expectedSeq; seq < sd.Seq && len(lost) < maxLostReport; seq++ { + if _, ok := lost[seq]; !ok { + lost[seq] = struct{}{} + } + } + + default: + // sd.Seq < expectedSeq 인 경우: 이미 처리했거나 Ack 로 커버된 프레임 → 무시. (ko) + // sd.Seq < expectedSeq: already processed/acked frame → ignore. (en) + } + + // 수신 측 StreamAck 전송: + // - AckSeq: 0부터 시작해 연속으로 수신 완료한 마지막 시퀀스 (expectedSeq-1) + // - LostSeqs: 현재 윈도우 내에서 누락된 시퀀스 중 상한 개수(maxLostReport)까지만 포함 (ko) + // Send receiver-side StreamAck: + // - AckSeq: last contiguously received sequence starting from 0 (expectedSeq-1) + // - LostSeqs: up to maxLostReport missing sequences in the current window. (en) + var ackSeq uint64 + if expectedSeq == 0 { + ackSeq = 0 + } else { + ackSeq = expectedSeq - 1 + } + + lostSeqs := make([]uint64, 0, len(lost)) + for seq := range lost { + if seq >= expectedSeq { + lostSeqs = append(lostSeqs, seq) + } + } + if len(lostSeqs) > 0 { + sort.Slice(lostSeqs, func(i, j int) bool { return lostSeqs[i] < lostSeqs[j] }) + if len(lostSeqs) > maxLostReport { + lostSeqs = lostSeqs[:maxLostReport] + } + } + + ackEnv := protocol.Envelope{ + Type: protocol.MessageTypeStreamAck, + StreamAck: &protocol.StreamAck{ + ID: streamID, + AckSeq: ackSeq, + LostSeqs: lostSeqs, + }, + } + if err := w.codec.Encode(w.sess, &ackEnv); err != nil { + return nil, fmt.Errorf("send stream ack: %w", err) } case protocol.MessageTypeStreamClose: @@ -630,6 +887,7 @@ func registerSessionForDomain(domain string, sess dtls.Session, logger logging.L logger: logger.With(logging.Fields{"component": "dtls_session_wrapper", "domain": d}), pending: make(map[protocol.StreamID]*pendingRequest), readerDone: make(chan struct{}), + streamSenders: make(map[protocol.StreamID]*streamSender), } // Start background reader goroutine to demultiplex incoming responses