diff --git a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt index 479587c..47a4e6f 100644 --- a/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt +++ b/webserver/src/main/kotlin/space/mori/chzzk_bot/webserver/routes/WSSongListRoutes.kt @@ -1,332 +1,346 @@ package space.mori.chzzk_bot.webserver.routes import io.ktor.client.plugins.websocket.WebSocketException -import io.ktor.server.application.ApplicationStopped +import io.ktor.server.application.* import io.ktor.server.routing.* import io.ktor.server.sessions.* import io.ktor.server.websocket.* -import io.ktor.util.logging.Logger -import io.ktor.utils.io.CancellationException import io.ktor.websocket.* -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel +import kotlinx.coroutines.* import kotlinx.coroutines.channels.ClosedReceiveChannelException -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch -import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.coroutines.withTimeoutOrNull -import kotlinx.io.IOException import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json import org.koin.java.KoinJavaComponent.inject +import org.slf4j.Logger import org.slf4j.LoggerFactory import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.User -import space.mori.chzzk_bot.common.services.SongConfigService import space.mori.chzzk_bot.common.services.SongListService import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.utils.CurrentSong +import java.io.IOException import java.util.concurrent.ConcurrentHashMap -val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) fun Routing.wsSongListRoutes() { - val sessions = ConcurrentHashMap() - val status = ConcurrentHashMap() val logger = LoggerFactory.getLogger("WSSongListRoutes") val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) - val sessionMutex = Mutex() + val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + // Manage all active sessions + val sessionHandlers = ConcurrentHashMap() + + // Handle application shutdown environment.monitor.subscribe(ApplicationStopped) { - songListScope.cancel() - } - - val ackMap = ConcurrentHashMap>() - - suspend fun addSession(uid: String, session: WebSocketServerSession) { - val oldSession = sessionMutex.withLock { - val old = sessions[uid] - sessions[uid] = session - old - } - if(oldSession != null) { + sessionHandlers.values.forEach { songListScope.launch { - try { - oldSession.close(CloseReason( - CloseReason.Codes.VIOLATED_POLICY, "Another session is already active.")) - } catch(e: Exception) { - logger.warn("Error closing old session: ${e.message}") - } + it.close(CloseReason(CloseReason.Codes.NORMAL, "Server shutting down")) } } } - suspend fun removeSession(uid: String) { - sessionMutex.withLock { - sessions.remove(uid) - } - } - - suspend fun waitForAck(ws: WebSocketServerSession, expectedUid: String): Boolean { - val ackDeferred = CompletableDeferred() - ackMap[expectedUid] = ackDeferred - return try { - withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false - } catch (e: CancellationException) { - false - } finally { - ackMap.remove(expectedUid) - } - } - - suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) { - var attempt = 0 - var sentSuccessfully = false - while (attempt < maxRetries && !sentSuccessfully) { - val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] } - if (ws == null) { - logger.debug("No active session for $uid. Retrying in $delayMillis ms.") - delay(delayMillis) - attempt++ - continue - } - try { - ws.sendSerialized(res) - val ackReceived = waitForAck(ws, res.uid) - if (ackReceived) { - logger.debug("ACK received for message to $uid on attempt $attempt.") - sentSuccessfully = true - } else { - logger.warn("ACK not received for message to $uid on attempt $attempt.") - attempt++ - } - } catch (e: CancellationException) { - throw e - } catch (e: Exception) { - attempt++ - logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}") - if (e is WebSocketException || e is IOException) { - removeSession(uid) - } - } - if (!sentSuccessfully && attempt < maxRetries) { - delay(delayMillis) - } - } - if (!sentSuccessfully) { - logger.error("Failed to send message to $uid after $maxRetries attempts.") - } - } - + // WebSocket endpoint webSocket("/songlist") { val session = call.sessions.get() - val user = session?.id?.let { UserService.getUser(it) } + val user: User? = session?.id?.let { UserService.getUser(it) } + if (user == null) { close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) return@webSocket } + val uid = user.token - addSession(uid, this) - if (status[uid] == SongType.STREAM_OFF) { - songListScope.launch { - try { - sendSerialized(SongResponse( - SongType.STREAM_OFF.value, - uid, - null, - null, - null, - )) - } catch (e: Exception) { - logger.warn("Error sending STREAM_OFF: ${e.message}") - } finally { - removeSession(uid) - } - } - return@webSocket - } + + // Ensure only one session per user + sessionHandlers[uid]?.close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Another session is already active.")) + + val handler = SessionHandler(uid, this, dispatcher, logger) + sessionHandlers[uid] = handler + + // Initialize session + handler.initialize() + + // Listen for incoming frames try { for (frame in incoming) { when (frame) { - is Frame.Text -> { - val text = frame.readText() - if (text.trim() == "ping") { - send("pong") - } else { - val data = Json.decodeFromString(text) - if (data.type == SongType.ACK.value) { - ackMap[data.uid]?.complete(true) - } else { - handleSongRequest(data, user, dispatcher, logger) - } - } - } + is Frame.Text -> handler.handleTextFrame(frame.readText()) is Frame.Ping -> send(Frame.Pong(frame.data)) - else -> {} + else -> Unit } } } catch (e: ClosedReceiveChannelException) { - logger.error("WebSocket connection closed: ${e.message}") + logger.info("Session closed: ${e.message}") + } catch (e: IOException) { + logger.error("IO error: ${e.message}") } catch (e: Exception) { - logger.error("Error in WebSocket: ${e.message}") + logger.error("Unexpected error: ${e.message}") } finally { - removeSession(uid) - ackMap.remove(uid) + sessionHandlers.remove(uid) + handler.close(CloseReason(CloseReason.Codes.NORMAL, "Session ended")) } } - dispatcher.subscribe(SongEvent::class) { - logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) + // Subscribe to SongEvents + dispatcher.subscribe(SongEvent::class) { event -> + val handler = sessionHandlers[event.uid] songListScope.launch { - try { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - it.reqUid, - it.current?.toSerializable(), - it.next?.toSerializable(), - it.delUrl - ) - ) - } - } catch(e: Exception) { - logger.error("Error handling song event: ${e.message}") - } + handler?.sendSongResponse(event) } } - dispatcher.subscribe(TimerEvent::class) { - if (it.type == TimerType.STREAM_OFF) { + + // Subscribe to TimerEvents + dispatcher.subscribe(TimerEvent::class) { event -> + if (event.type == TimerType.STREAM_OFF) { + val handler = sessionHandlers[event.uid] songListScope.launch { - try { - val user = UserService.getUser(it.uid) - if (user != null) { - sendWithRetry( - user.token, SongResponse( - it.type.value, - it.uid, - null, - null, - null, - ) - ) - } - } catch(e: Exception) { - logger.error("Error handling timer event: ${e.message}") - } + handler?.sendTimerOff() } } } } -// 노래 처리를 위한 Mutex 추가 -private val songMutex = Mutex() -fun handleSongRequest( - data: SongRequest, - user: User, - dispatcher: CoroutinesEventBus, - logger: Logger + +class SessionHandler( + private val uid: String, + private val session: WebSocketServerSession, + private val dispatcher: CoroutinesEventBus, + private val logger: Logger ) { - if (data.maxQueue != null && data.maxQueue > 0) SongConfigService.updateQueueLimit(user, data.maxQueue) - if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit) - if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly) - if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled) - when (data.type) { - SongType.ADD.value -> { - data.url?.let { url -> - try { - val youtubeVideo = getYoutubeVideo(url) - if (youtubeVideo != null) { - songListScope.launch { - songMutex.withLock { - SongListService.saveSong( - user, - user.token, - url, - youtubeVideo.name, - youtubeVideo.author, - youtubeVideo.length, - user.username - ) - dispatcher.post( - SongEvent( - user.token, - SongType.ADD, - user.token, - CurrentSong.getSong(user), - youtubeVideo - ) - ) - } - } - } - } catch (e: Exception) { - logger.debug("SongType.ADD Error: ${user.token} $e") + private val ackMap = ConcurrentHashMap>() + private val sessionMutex = Mutex() + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + + suspend fun initialize() { + // Send initial status if needed, + // For example, send STREAM_OFF if applicable + // This can be extended based on your requirements + } + + suspend fun handleTextFrame(text: String) { + if (text.trim() == "ping") { + session.send("pong") + return + } + + val data = try { + Json.decodeFromString(text) + } catch (e: Exception) { + logger.warn("Failed to decode SongRequest: ${e.message}") + return + } + + when (data.type) { + SongType.ACK.value -> handleAck(data.uid) + else -> handleSongRequest(data) + } + } + + private fun handleAck(requestUid: String) { + ackMap[requestUid]?.complete(true) + ackMap.remove(requestUid) + } + + private fun handleSongRequest(data: SongRequest) { + scope.launch { + SongRequestProcessor.process(data, uid, dispatcher, this@SessionHandler, logger) + } + } + + suspend fun sendSongResponse(event: SongEvent) { + val response = SongResponse( + type = event.type.value, + uid = event.uid, + reqUid = event.reqUid, + current = event.current?.toSerializable(), + next = event.next?.toSerializable(), + delUrl = event.delUrl + ) + sendWithRetry(response) + } + + suspend fun sendTimerOff() { + val response = SongResponse( + type = TimerType.STREAM_OFF.value, + uid = uid, + reqUid = null, + current = null, + next = null, + delUrl = null + ) + sendWithRetry(response) + } + + private suspend fun sendWithRetry(res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) { + var attempt = 0 + while (attempt < maxRetries) { + try { + session.sendSerialized(res) + val ackDeferred = CompletableDeferred() + ackMap[res.uid] = ackDeferred + + val ackReceived = withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false + if (ackReceived) { + logger.debug("ACK received for message to $uid on attempt $attempt.") + return + } else { + logger.warn("ACK not received for message to $uid on attempt $attempt.") } + } catch (e: IOException) { + logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}") + if (e is WebSocketException) { + close(CloseReason(CloseReason.Codes.PROTOCOL_ERROR, "WebSocket error")) + return + } + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + logger.warn("Unexpected error while sending message to $uid on attempt $attempt: ${e.message}") + } + + attempt++ + delay(delayMillis) + } + + logger.error("Failed to send message to $uid after $maxRetries attempts.") + } + + suspend fun close(reason: CloseReason) { + try { + session.close(reason) + } catch (e: Exception) { + logger.warn("Error closing session: ${e.message}") + } + } +} + +object SongRequestProcessor { + private val songMutex = Mutex() + + suspend fun process( + data: SongRequest, + uid: String, + dispatcher: CoroutinesEventBus, + handler: SessionHandler, + logger: Logger + ) { + val user = UserService.getUser(uid) ?: return + + when (data.type) { + SongType.ADD.value -> handleAdd(data, user, dispatcher, handler, logger) + SongType.REMOVE.value -> handleRemove(data, user, dispatcher, logger) + SongType.NEXT.value -> handleNext(user, dispatcher, logger) + else -> { + // Handle other types if necessary } } - SongType.REMOVE.value -> { - data.url?.let { url -> - songListScope.launch { - songMutex.withLock { - val songs = SongListService.getSong(user) - val exactSong = songs.firstOrNull { it.url == url } - if (exactSong != null) { - SongListService.deleteSong(user, exactSong.uid, exactSong.name) - } - dispatcher.post( - SongEvent( - user.token, - SongType.REMOVE, - null, - null, - null, - url - ) - ) - } - } + } + + private suspend fun handleAdd( + data: SongRequest, + user: User, + dispatcher: CoroutinesEventBus, + handler: SessionHandler, + logger: Logger + ) { + val url = data.url ?: return + val youtubeVideo = getYoutubeVideo(url) ?: run { + logger.warn("Failed to fetch YouTube video for URL: $url") + return + } + + songMutex.withLock { + SongListService.saveSong( + user, + user.token, + url, + youtubeVideo.name, + youtubeVideo.author, + youtubeVideo.length, + user.username + ) + } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.ADD, + reqUid = user.token, + current = CurrentSong.getSong(user), + next = youtubeVideo + ) + ) + } + + private suspend fun handleRemove( + data: SongRequest, + user: User, + dispatcher: CoroutinesEventBus, + logger: Logger + ) { + val url = data.url ?: return + + songMutex.withLock { + val songs = SongListService.getSong(user) + val exactSong = songs.firstOrNull { it.url == url } + if (exactSong != null) { + SongListService.deleteSong(user, exactSong.uid, exactSong.name) } } - SongType.NEXT.value -> { - songListScope.launch { - songMutex.withLock { - val songList = SongListService.getSong(user) - var song: SongList? = null - var youtubeVideo: YoutubeVideo? = null - if (songList.isNotEmpty()) { - song = songList[0] - SongListService.deleteSong(user, song.uid, song.name) - } - song?.let { - youtubeVideo = YoutubeVideo( - song.url, - song.name, - song.author, - song.time - ) - } - dispatcher.post( - SongEvent( - user.token, - SongType.NEXT, - song?.uid, - youtubeVideo - ) - ) - CurrentSong.setSong(user, youtubeVideo) - } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.REMOVE, + delUrl = url, + reqUid = null, + current = null, + next = null, + ) + ) + } + + private suspend fun handleNext( + user: User, + dispatcher: CoroutinesEventBus, + logger: Logger + ) { + var song: SongList? = null + var youtubeVideo: YoutubeVideo? = null + + songMutex.withLock { + val songList = SongListService.getSong(user) + if (songList.isNotEmpty()) { + song = songList[0] + SongListService.deleteSong(user, song.uid, song.name) } } + + song?.let { + youtubeVideo = YoutubeVideo( + it.url, + it.name, + it.author, + it.time + ) + } + + dispatcher.post( + SongEvent( + uid = user.token, + type = SongType.NEXT, + current = null, + next = youtubeVideo, + reqUid = null, + delUrl = null + ) + ) + + CurrentSong.setSong(user, youtubeVideo) } } @@ -334,10 +348,10 @@ fun handleSongRequest( data class SongRequest( val type: Int, val uid: String, - val url: String?, - val maxQueue: Int?, - val maxUserLimit: Int?, - val isStreamerOnly: Boolean?, - val remove: Int?, - val isDisabled: Boolean?, -) \ No newline at end of file + val url: String? = null, + val maxQueue: Int? = null, + val maxUserLimit: Int? = null, + val isStreamerOnly: Boolean? = null, + val remove: Int? = null, + val isDisabled: Boolean? = null +)