From aa959760052e787a6c9c81353a27157e7224820b Mon Sep 17 00:00:00 2001 From: dalbodeule <11470513+dalbodeule@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:45:25 +0900 Subject: [PATCH] Refactor WebSocket song list handling and improve session logic Replaced individual WebSocket session management with `SessionHandler` to centralize and streamline logic. Improved code readability, reliability, and maintainability by reducing redundancy and encapsulating session and request handling in dedicated classes. Added retry mechanisms, acknowledgment handling, and better application shutdown handling. --- .../webserver/routes/WSSongListRoutes.kt | 546 +++++++++--------- 1 file changed, 280 insertions(+), 266 deletions(-) diff --git a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt index 479587c..47a4e6f 100644 --- a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt +++ b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt @@ -1,332 +1,346 @@ package space.mori.chzzk_bot.webserver.routes import io.ktor.client.plugins.websocket.WebSocketException -import io.ktor.server.application.ApplicationStopped +import io.ktor.server.application.* import io.ktor.server.routing.* import io.ktor.server.sessions.* import io.ktor.server.websocket.* -import io.ktor.util.logging.Logger -import io.ktor.utils.io.CancellationException import io.ktor.websocket.* -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.* import kotlinx.coroutines.channels.ClosedReceiveChannelException -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch -import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.coroutines.withTimeoutOrNull -import kotlinx.io.IOException import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import org.koin.java.KoinJavaComponent.inject +import org.slf4j.Logger import org.slf4j.LoggerFactory import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.User -import space.mori.chzzk_bot.common.services.SongConfigService import space.mori.chzzk_bot.common.services.SongListService import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.utils.CurrentSong +import java.io.IOException import java.util.concurrent.ConcurrentHashMap -val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) fun Routing.wsSongListRoutes() { - val sessions = ConcurrentHashMap() - val status = ConcurrentHashMap() val logger = LoggerFactory.getLogger("WSSongListRoutes") val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) - val sessionMutex = Mutex() + val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + // Manage all active sessions + val sessionHandlers = ConcurrentHashMap() + + // Handle application shutdown environment.monitor.subscribe(ApplicationStopped) { - songListScope.cancel() - } - - val ackMap = ConcurrentHashMap>() - - suspend fun addSession(uid: String, session: WebSocketServerSession) { - val oldSession = sessionMutex.withLock { - val old = sessions[uid] - sessions[uid] = session - old - } - if(oldSession != null) { + sessionHandlers.values.forEach { songListScope.launch { - try { - oldSession.close(CloseReason( - CloseReason.Codes.VIOLATED_POLICY, "Another session is already active.")) - } catch(e: Exception) { - logger.warn("Error closing old session: ${e.message}") - } + it.close(CloseReason(CloseReason.Codes.NORMAL, "Server shutting down")) } } } - suspend fun removeSession(uid: String) { - sessionMutex.withLock { - sessions.remove(uid) - } - } - - suspend fun waitForAck(ws: WebSocketServerSession, expectedUid: String): Boolean { - val ackDeferred = CompletableDeferred() - ackMap[expectedUid] = ackDeferred - return try { - withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false - } catch (e: CancellationException) { - false - } finally { - ackMap.remove(expectedUid) - } - } - - suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) { - var attempt = 0 - var sentSuccessfully = false - while (attempt < maxRetries && !sentSuccessfully) { - val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] } - if (ws == null) { - logger.debug("No active session for $uid. Retrying in $delayMillis ms.") - delay(delayMillis) - attempt++ - continue - } - try { - ws.sendSerialized(res) - val ackReceived = waitForAck(ws, res.uid) - if (ackReceived) { - logger.debug("ACK received for message to $uid on attempt $attempt.") - sentSuccessfully = true - } else { - logger.warn("ACK not received for message to $uid on attempt $attempt.") - attempt++ - } - } catch (e: CancellationException) { - throw e - } catch (e: Exception) { - attempt++ - logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}") - if (e is WebSocketException || e is IOException) { - removeSession(uid) - } - } - if (!sentSuccessfully && attempt < maxRetries) { - delay(delayMillis) - } - } - if (!sentSuccessfully) { - logger.error("Failed to send message to $uid after $maxRetries attempts.") - } - } - + // WebSocket endpoint webSocket("/songlist") { val session = call.sessions.get() - val user = session?.id?.let { UserService.getUser(it) } + val user: User? = session?.id?.let { UserService.getUser(it) } + if (user == null) { close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) return@webSocket } + val uid = user.token - addSession(uid, this) - if (status[uid] == SongType.STREAM_OFF) { - songListScope.launch { - try { - sendSerialized(SongResponse( - SongType.STREAM_OFF.value, - uid, - null, - null, - null, - )) - } catch (e: Exception) { - logger.warn("Error sending STREAM_OFF: ${e.message}") - } finally { - removeSession(uid) - } - } - return@webSocket - } + + // Ensure only one session per user + sessionHandlers[uid]?.close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Another session is already active.")) + + val handler = SessionHandler(uid, this, dispatcher, logger) + sessionHandlers[uid] = handler + + // Initialize session + handler.initialize() + + // Listen for incoming frames try { for (frame in incoming) { when (frame) { - is Frame.Text -> { - val text = frame.readText() - if (text.trim() == "ping") { - send("pong") - } else { - val data = Json.decodeFromString(text) - if (data.type == SongType.ACK.value) { - ackMap[data.uid]?.complete(true) - } else { - handleSongRequest(data, user, dispatcher, logger) - } - } - } + is Frame.Text -> handler.handleTextFrame(frame.readText()) is Frame.Ping -> send(Frame.Pong(frame.data)) - else -> {} + else -> Unit } } } catch (e: ClosedReceiveChannelException) { - logger.error("WebSocket connection closed: ${e.message}") + logger.info("Session closed: ${e.message}") + } catch (e: IOException) { + logger.error("IO error: ${e.message}") } catch (e: Exception) { - logger.error("Error in WebSocket: ${e.message}") + logger.error("Unexpected error: ${e.message}") } finally { - removeSession(uid) - ackMap.remove(uid) + sessionHandlers.remove(uid) + handler.close(CloseReason(CloseReason.Codes.NORMAL, "Session ended")) } } - dispatcher.subscribe(SongEvent::class) { - logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) + // Subscribe to SongEvents + dispatcher.subscribe(SongEvent::class) { event -> + val handler = sessionHandlers[event.uid] songListScope.launch { - try { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - it.reqUid, - it.current?.toSerializable(), - it.next?.toSerializable(), - it.delUrl - ) - ) - } - } catch(e: Exception) { - logger.error("Error handling song event: ${e.message}") - } + handler?.sendSongResponse(event) } } - dispatcher.subscribe(TimerEvent::class) { - if (it.type == TimerType.STREAM_OFF) { + + // Subscribe to TimerEvents + dispatcher.subscribe(TimerEvent::class) { event -> + if (event.type == TimerType.STREAM_OFF) { + val handler = sessionHandlers[event.uid] songListScope.launch { - try { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - null, - null, - null, - ) - ) - } - } catch(e: Exception) { - logger.error("Error handling timer event: ${e.message}") - } + handler?.sendTimerOff() } } } } -// 노래 처리를 위한 Mutex 추가 -private val songMutex = Mutex() -fun handleSongRequest( - data: SongRequest, - user: User, - dispatcher: CoroutinesEventBus, - logger: Logger + +class SessionHandler( + private val uid: String, + private val session: WebSocketServerSession, + private val dispatcher: CoroutinesEventBus, + private val logger: Logger ) { - if (data.maxQueue != null && data.maxQueue > 0) SongConfigService.updateQueueLimit(user, data.maxQueue) - if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit) - if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly) - if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled) - when (data.type) { - SongType.ADD.value -> { - data.url?.let { url -> - try { - val youtubeVideo = getYoutubeVideo(url) - if (youtubeVideo != null) { - songListScope.launch { - songMutex.withLock { - SongListService.saveSong( - user, - user.token, - url, - youtubeVideo.name, - youtubeVideo.author, - youtubeVideo.length, - user.username - ) - dispatcher.post( - SongEvent( - user.token, - SongType.ADD, - user.token, - CurrentSong.getSong(user), - youtubeVideo - ) - ) - } - } - } - } catch (e: Exception) { - logger.debug("SongType.ADD Error: ${user.token} $e") + private val ackMap = ConcurrentHashMap>() + private val sessionMutex = Mutex() + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + + suspend fun initialize() { + // Send initial status if needed, + // For example, send STREAM_OFF if applicable + // This can be extended based on your requirements + } + + suspend fun handleTextFrame(text: String) { + if (text.trim() == "ping") { + session.send("pong") + return + } + + val data = try { + Json.decodeFromString(text) + } catch (e: Exception) { + logger.warn("Failed to decode SongRequest: ${e.message}") + return + } + + when (data.type) { + SongType.ACK.value -> handleAck(data.uid) + else -> handleSongRequest(data) + } + } + + private fun handleAck(requestUid: String) { + ackMap[requestUid]?.complete(true) + ackMap.remove(requestUid) + } + + private fun handleSongRequest(data: SongRequest) { + scope.launch { + SongRequestProcessor.process(data, uid, dispatcher, this@SessionHandler, logger) + } + } + + suspend fun sendSongResponse(event: SongEvent) { + val response = SongResponse( + type = event.type.value, + uid = event.uid, + reqUid = event.reqUid, + current = event.current?.toSerializable(), + next = event.next?.toSerializable(), + delUrl = event.delUrl + ) + sendWithRetry(response) + } + + suspend fun sendTimerOff() { + val response = SongResponse( + type = TimerType.STREAM_OFF.value, + uid = uid, + reqUid = null, + current = null, + next = null, + delUrl = null + ) + sendWithRetry(response) + } + + private suspend fun sendWithRetry(res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) { + var attempt = 0 + while (attempt < maxRetries) { + try { + session.sendSerialized(res) + val ackDeferred = CompletableDeferred() + ackMap[res.uid] = ackDeferred + + val ackReceived = withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false + if (ackReceived) { + logger.debug("ACK received for message to $uid on attempt $attempt.") + return + } else { + logger.warn("ACK not received for message to $uid on attempt $attempt.") } + } catch (e: IOException) { + logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}") + if (e is WebSocketException) { + close(CloseReason(CloseReason.Codes.PROTOCOL_ERROR, "WebSocket error")) + return + } + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.warn("Unexpected error while sending message to $uid on attempt $attempt: ${e.message}") + } + + attempt++ + delay(delayMillis) + } + + logger.error("Failed to send message to $uid after $maxRetries attempts.") + } + + suspend fun close(reason: CloseReason) { + try { + session.close(reason) + } catch (e: Exception) { + logger.warn("Error closing session: ${e.message}") + } + } +} + +object SongRequestProcessor { + private val songMutex = Mutex() + + suspend fun process( + data: SongRequest, + uid: String, + dispatcher: CoroutinesEventBus, + handler: SessionHandler, + logger: Logger + ) { + val user = UserService.getUser(uid) ?: return + + when (data.type) { + SongType.ADD.value -> handleAdd(data, user, dispatcher, handler, logger) + SongType.REMOVE.value -> handleRemove(data, user, dispatcher, logger) + SongType.NEXT.value -> handleNext(user, dispatcher, logger) + else -> { + // Handle other types if necessary } } - SongType.REMOVE.value -> { - data.url?.let { url -> - songListScope.launch { - songMutex.withLock { - val songs = SongListService.getSong(user) - val exactSong = songs.firstOrNull { it.url == url } - if (exactSong != null) { - SongListService.deleteSong(user, exactSong.uid, exactSong.name) - } - dispatcher.post( - SongEvent( - user.token, - SongType.REMOVE, - null, - null, - null, - url - ) - ) - } - } + } + + private suspend fun handleAdd( + data: SongRequest, + user: User, + dispatcher: CoroutinesEventBus, + handler: SessionHandler, + logger: Logger + ) { + val url = data.url ?: return + val youtubeVideo = getYoutubeVideo(url) ?: run { + logger.warn("Failed to fetch YouTube video for URL: $url") + return + } + + songMutex.withLock { + SongListService.saveSong( + user, + user.token, + url, + youtubeVideo.name, + youtubeVideo.author, + youtubeVideo.length, + user.username + ) + } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.ADD, + reqUid = user.token, + current = CurrentSong.getSong(user), + next = youtubeVideo + ) + ) + } + + private suspend fun handleRemove( + data: SongRequest, + user: User, + dispatcher: CoroutinesEventBus, + logger: Logger + ) { + val url = data.url ?: return + + songMutex.withLock { + val songs = SongListService.getSong(user) + val exactSong = songs.firstOrNull { it.url == url } + if (exactSong != null) { + SongListService.deleteSong(user, exactSong.uid, exactSong.name) } } - SongType.NEXT.value -> { - songListScope.launch { - songMutex.withLock { - val songList = SongListService.getSong(user) - var song: SongList? = null - var youtubeVideo: YoutubeVideo? = null - if (songList.isNotEmpty()) { - song = songList[0] - SongListService.deleteSong(user, song.uid, song.name) - } - song?.let { - youtubeVideo = YoutubeVideo( - song.url, - song.name, - song.author, - song.time - ) - } - dispatcher.post( - SongEvent( - user.token, - SongType.NEXT, - song?.uid, - youtubeVideo - ) - ) - CurrentSong.setSong(user, youtubeVideo) - } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.REMOVE, + delUrl = url, + reqUid = null, + current = null, + next = null, + ) + ) + } + + private suspend fun handleNext( + user: User, + dispatcher: CoroutinesEventBus, + logger: Logger + ) { + var song: SongList? = null + var youtubeVideo: YoutubeVideo? = null + + songMutex.withLock { + val songList = SongListService.getSong(user) + if (songList.isNotEmpty()) { + song = songList[0] + SongListService.deleteSong(user, song.uid, song.name) } } + + song?.let { + youtubeVideo = YoutubeVideo( + it.url, + it.name, + it.author, + it.time + ) + } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.NEXT, + current = null, + next = youtubeVideo, + reqUid = null, + delUrl = null + ) + ) + + CurrentSong.setSong(user, youtubeVideo) } } @@ -334,10 +348,10 @@ fun handleSongRequest( data class SongRequest( val type: Int, val uid: String, - val url: String?, - val maxQueue: Int?, - val maxUserLimit: Int?, - val isStreamerOnly: Boolean?, - val remove: Int?, - val isDisabled: Boolean?, -) \ No newline at end of file + val url: String? = null, + val maxQueue: Int? = null, + val maxUserLimit: Int? = null, + val isStreamerOnly: Boolean? = null, + val remove: Int? = null, + val isDisabled: Boolean? = null +)