From 02cede87f873f3d3c37f0e745c16ba373e410368 Mon Sep 17 00:00:00 2001 From: dalbodeule <11470513+dalbodeule@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:58:56 +0900 Subject: [PATCH] Add SongListWebSocketManager and refactor WebSocket routes Introduced SongListWebSocketManager for managing WebSocket sessions, including ping-pong handling and retry mechanisms. Refactored WSSongListRoutes to delegate session management and simplify logic by leveraging the new manager class. --- .../webserver/routes/WSSongListRoutes.kt | 178 ++---------- .../utils/SongListWebSocketManager.kt | 271 ++++++++++++++++++ 2 files changed, 299 insertions(+), 150 deletions(-) create mode 100644 webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/utils/SongListWebSocketManager.kt diff --git a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt index 01b145b..479d260 100644 --- a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt +++ b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt @@ -3,16 +3,14 @@ import io.ktor.server.sessions.* import io.ktor.server.websocket.* import io.ktor.util.logging.Logger import io.ktor.websocket.* +import io.ktor.server.application.* import io.ktor.websocket.Frame.* import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.channels.ClosedReceiveChannelException -import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import org.koin.java.KoinJavaComponent.inject -import org.slf4j.LoggerFactory import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.SongLists.uid @@ -23,178 +21,57 @@ import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.webserver.UserSession -import space.mori.chzzk_bot.webserver.routes.SongResponse -import space.mori.chzzk_bot.webserver.routes.toSerializable import space.mori.chzzk_bot.webserver.utils.CurrentSong -import java.util.concurrent.ConcurrentHashMap - -fun Routing.wsSongListRoutes() { - val sessions = ConcurrentHashMap() - val status = ConcurrentHashMap() - val logger = LoggerFactory.getLogger("WSSongListRoutes") +import space.mori.chzzk_bot.webserver.utils.SongListWebSocketManager +fun Route.wsSongListRoutes(songListWebSocketManager: SongListWebSocketManager) { val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) - fun addSession(uid: String, session: WebSocketServerSession) { - if (sessions[uid] != null) { - CoroutineScope(Dispatchers.Default).launch { - sessions[uid]?.close( - CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Duplicated sessions.") - ) - } - } - sessions[uid] = session - } - - fun removeSession(uid: String) { - sessions.remove(uid) - } - - suspend fun waitForAck(ws: WebSocketServerSession, expectedType: Int): Boolean { - val timeout = 5000L // 5 seconds timeout - val startTime = System.currentTimeMillis() - while (System.currentTimeMillis() - startTime < timeout) { - for (frame in ws.incoming) { - if (frame is Text) { - val message = frame.readText() - if(message == "ping") { - return true - } - val data = Json.decodeFromString(message) - if (data.type == SongType.ACK.value) { - return true // ACK received - } - } - } - delay(100) // Check every 100 ms - } - return false // Timeout - } - - suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) { - var attempt = 0 - var sentSuccessfully = false - - while (attempt < maxRetries && !sentSuccessfully) { - val ws = sessions[uid] - try { - if(ws == null) { - delay(delayMillis) - continue - } - // Attempt to send the message - ws.sendSerialized(res) - logger.debug("Message sent successfully to $uid on attempt $attempt") - // Wait for ACK - val ackReceived = waitForAck(ws, res.type) - if (ackReceived == true) { - sentSuccessfully = true - } else { - logger.warn("ACK not received for message to $uid on attempt $attempt.") - } - } catch (e: Exception) { - attempt++ - logger.warn("Failed to send message to $uid on attempt $attempt. Retrying in $delayMillis ms.") - logger.warn(e.stackTraceToString()) - } finally { - // Wait before retrying - delay(delayMillis) - } - } - - if (!sentSuccessfully) { - logger.error("Failed to send message to $uid after $maxRetries attempts.") - } - } - webSocket("/songlist") { val session = call.sessions.get() - val user = session?.id?.let { UserService.getUser(it) } - if (user == null) { - close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) + val uid = session?.id + if (uid == null) { + close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid session")) return@webSocket } - val uid = user.token - - addSession(uid, this) - - if (status[uid] == SongType.STREAM_OFF) { - CoroutineScope(Dispatchers.Default).launch { - sendSerialized(SongResponse( - SongType.STREAM_OFF.value, - uid, - null, - null, - null, - )) - } - removeSession(uid) - } - try { + songListWebSocketManager.addSession(uid, this) for (frame in incoming) { when (frame) { is Text -> { - if (frame.readText().trim() == "ping") { - send("pong") + val text = frame.readText().trim() + if (text == SongListWebSocketManager.PING_MESSAGE) { + send(SongListWebSocketManager.PONG_MESSAGE) + songListWebSocketManager.handlePong(uid) } else { - val data = frame.readText().let { Json.decodeFromString(it) } - // Handle song requests - handleSongRequest(data, user, dispatcher, logger) + text.let { Json.decodeFromString(it) }.let { data -> + val user = session.id.let { UserService.getUser(it) } + + if(user == null) { + songListWebSocketManager.removeSession(uid) + return@webSocket + } + + handleSongRequest(data, user, dispatcher, songListWebSocketManager.logger) + }.runCatching { songListWebSocketManager.logger.error("Failed to parse WebSocket message as SongRequest.") } } } + is Ping -> send(Pong(frame.data)) - else -> "" + else -> songListWebSocketManager.logger.warn("Unsupported frame type received.") } } - } catch (e: ClosedReceiveChannelException) { - logger.error("Error in WebSocket: ${e.message}") + } catch (e: Exception) { + songListWebSocketManager.logger.error("WebSocket error: ${e.message}") } finally { - removeSession(uid) - } - } - - dispatcher.subscribe(SongEvent::class) { - logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) - CoroutineScope(Dispatchers.Default).launch { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - it.reqUid, - it.current?.toSerializable(), - it.next?.toSerializable(), - it.delUrl - ) - ) - } - } - } - - dispatcher.subscribe(TimerEvent::class) { - if (it.type == TimerType.STREAM_OFF) { - CoroutineScope(Dispatchers.Default).launch { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - null, - null, - null, - ) - ) - } - } + songListWebSocketManager.removeSession(uid) } } } + suspend fun handleSongRequest( data: SongRequest, user: User, @@ -289,6 +166,7 @@ suspend fun handleSongRequest( } } + @Serializable data class SongRequest( val type: Int, diff --git a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/utils/SongListWebSocketManager.kt b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/utils/SongListWebSocketManager.kt new file mode 100644 index 0000000..57c9097 --- /dev/null +++ b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/utils/SongListWebSocketManager.kt @@ -0,0 +1,271 @@ +package space.mori.chzzk_bot.webserver.utils + +import io.ktor.server.websocket.WebSocketServerSession +import io.ktor.util.logging.Logger +import io.ktor.websocket.CloseReason +import io.ktor.websocket.Frame +import io.ktor.websocket.close +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.ClosedReceiveChannelException +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.serialization.json.Json +import space.mori.chzzk_bot.common.events.SongType +import space.mori.chzzk_bot.webserver.routes.SongResponse +import java.util.concurrent.ConcurrentHashMap +import kotlin.collections.set + +class SongListWebSocketManager(internal val logger: Logger) { + companion object { + private const val ACK_TIMEOUT_MS = 5000L + private const val ACK_CHECK_INTERVAL_MS = 100L + private const val MAX_RETRY_ATTEMPTS = 5 + private const val RETRY_DELAY_MS = 3000L + internal const val PING_MESSAGE = "ping" + internal const val PONG_MESSAGE = "pong" + private const val KEEP_ALIVE_INTERVAL_MS = 30000L // 30초 간격으로 핑-퐁 메시지 전송 + } + + private val sessions = ConcurrentHashMap() + private val status = ConcurrentHashMap() + private val lastActivity = ConcurrentHashMap() + private val pendingAcks = ConcurrentHashMap() + + // 핑-퐁 상태 관리를 위한 맵 + private val pingStatus = ConcurrentHashMap() + private val pingJobs = ConcurrentHashMap() + private val mutex = kotlinx.coroutines.sync.Mutex() + + init { + // 정기적으로 비활성 연결을 체크하는 백그라운드 작업 + CoroutineScope(Dispatchers.Default).launch { + while (true) { + try { + checkInactiveSessions() + delay(60000) // 1분마다 체크 + } catch (e: Exception) { + logger.error("비활성 세션 체크 중 오류 발생: ${e.message}") + } + } + } + } + + suspend fun addSession(uid: String, session: WebSocketServerSession) { + mutex.lock() + try { + sessions[uid] = session + lastActivity[uid] = System.currentTimeMillis() + + // 핑-퐁 작업 시작 + startPingPongJob(uid, session) + + logger.info("웹소켓 세션 추가됨: $uid") + } finally { + mutex.unlock() + } + } + + suspend fun removeSession(uid: String) { + mutex.lock() + try { + sessions.remove(uid) + status.remove(uid) + lastActivity.remove(uid) + pendingAcks.remove(uid) + + // 핑-퐁 작업 중지 + stopPingPongJob(uid) + + logger.info("웹소켓 세션 제거됨: $uid") + } finally { + mutex.unlock() + } + } + + private fun startPingPongJob(uid: String, session: WebSocketServerSession) { + pingStatus[uid] = true + pingJobs[uid] = CoroutineScope(Dispatchers.IO).launch { + while (pingStatus[uid] == true && sessions[uid] != null) { + try { + session.send(Frame.Text(PING_MESSAGE)) + logger.debug("핑 메시지 전송: $uid") + + // 응답 대기 (타임아웃 처리는 메시지 수신부에서) + mutex.lock() + try { + lastActivity[uid] = System.currentTimeMillis() + } finally { + mutex.unlock() + } + + delay(KEEP_ALIVE_INTERVAL_MS) + } catch (e: Exception) { + logger.error("핑-퐁 작업 중 오류 발생: ${e.message}") + + // 연결에 문제가 있으면 세션을 제거 + if (e is ClosedReceiveChannelException || e is java.io.IOException) { + logger.warn("연결 문제로 세션 제거: $uid") + removeSession(uid) + break + } + + delay(5000) // 오류 발생 시 잠시 대기 후 재시도 + } + } + } + } + + private fun stopPingPongJob(uid: String) { + pingStatus[uid] = false + pingJobs[uid]?.cancel() + pingJobs.remove(uid) + } + + // 비활성 세션 점검 및 제거 + private suspend fun checkInactiveSessions() { + val currentTime = System.currentTimeMillis() + val inactiveTimeout = 3 * KEEP_ALIVE_INTERVAL_MS // 3번의 핑-퐁 주기 이상 응답이 없으면 비활성으로 간주 + + val inactiveSessions = lastActivity.entries + .filter { currentTime - it.value > inactiveTimeout } + .map { it.key } + + mutex.lock() + try { + inactiveSessions.forEach { uid -> + logger.warn("비활성 세션 감지됨. 제거 중: $uid (마지막 활동: ${(currentTime - lastActivity[uid]!!) / 1000}초 전)") + + try { + sessions[uid]?.close(CloseReason(CloseReason.Codes.GOING_AWAY, "비활성 연결 감지")) + } catch (e: Exception) { + logger.error("비활성 세션 닫기 실패: ${e.message}") + } finally { + removeSession(uid) + } + } + } finally { + mutex.unlock() + } + + if (inactiveSessions.isNotEmpty()) { + logger.info("총 ${inactiveSessions.size}개의 비활성 세션이 제거됨") + } + } + + // 활동 기록 업데이트 + suspend fun updateActivity(uid: String) { + mutex.lock() + try { + lastActivity[uid] = System.currentTimeMillis() + } finally { + mutex.unlock() + } + } + + // 퐁 메시지 처리 + suspend fun handlePong(uid: String) { + updateActivity(uid) + logger.debug("퐁 메시지 수신: $uid") + } + + suspend fun waitForAck(ws: WebSocketServerSession): Boolean { + val sessionId = ws.hashCode().toString() + mutex.lock() + try { + pendingAcks[sessionId] = false + } finally { + mutex.unlock() + } + + val startTime = System.currentTimeMillis() + while (System.currentTimeMillis() - startTime < ACK_TIMEOUT_MS) { + mutex.lock() + try { + if (pendingAcks[sessionId] == true) { + pendingAcks.remove(sessionId) + return true + } + } finally { + mutex.unlock() + } + delay(ACK_CHECK_INTERVAL_MS) + } + + mutex.lock() + try { + pendingAcks.remove(sessionId) + } finally { + mutex.unlock() + } + return false + } + + suspend fun acknowledgeMessage(sessionId: String) { + mutex.lock() + try { + pendingAcks[sessionId] = true + } finally { + mutex.unlock() + } + } + + suspend fun sendWithRetry(uid: String, response: SongResponse) { + val session = sessions[uid] ?: run { + logger.warn("세션을 찾을 수 없음: $uid") + return + } + + var attempts = 0 + val jsonResponse = Json.encodeToString(SongResponse.serializer(), response) + + while (attempts < MAX_RETRY_ATTEMPTS) { + try { + session.send(Frame.Text(jsonResponse)) + updateActivity(uid) + + if (waitForAck(session)) { + logger.debug("메시지 전송 성공 (시도 ${attempts + 1}): $uid") + return + } else { + logger.warn("확인 응답 없음 (시도 ${attempts + 1}): $uid") + } + } catch (e: Exception) { + logger.error("메시지 전송 실패 (시도 ${attempts + 1}): ${e.message}") + + if (e is ClosedReceiveChannelException || e is java.io.IOException) { + logger.warn("연결 끊김으로 인한 세션 제거: $uid") + removeSession(uid) + return + } + } + + attempts++ + if (attempts < MAX_RETRY_ATTEMPTS) { + delay(RETRY_DELAY_MS) + } + } + + logger.error("최대 재시도 횟수 초과 후 메시지 전송 실패: $uid") + // 여러 번 실패 후 연결이 불안정한 것으로 판단하여 세션 제거 + removeSession(uid) + } + + suspend fun getUserStatus(uid: String): SongType? { + mutex.lock() + return try { + status[uid] + } finally { + mutex.unlock() + } + } + + suspend fun updateUserStatus(uid: String, songType: SongType) { + mutex.lock() + try { + status[uid] = songType + } finally { + mutex.unlock() + } + } +} \ No newline at end of file