Refactor WebSocket song list handling and improve session logic

Replaced individual WebSocket session management with `SessionHandler` to centralize and streamline logic. Improved code readability, reliability, and maintainability by reducing redundancy and encapsulating session and request handling in dedicated classes. Added retry mechanisms, acknowledgment handling, and better application shutdown handling.
This commit is contained in:
dalbodeule 2025-04-24 17:45:25 +09:00
parent c5a98943c0
commit aa95976005
No known key found for this signature in database
GPG Key ID: EFA860D069C9FA65

View File

@ -1,252 +1,258 @@
package space.mori.chzzk_bot.webserver.routes package space.mori.chzzk_bot.webserver.routes
import io.ktor.client.plugins.websocket.WebSocketException import io.ktor.client.plugins.websocket.WebSocketException
import io.ktor.server.application.ApplicationStopped import io.ktor.server.application.*
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.sessions.* import io.ktor.server.sessions.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger
import io.ktor.utils.io.CancellationException
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.IOException
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.SongList
import space.mori.chzzk_bot.common.models.User import space.mori.chzzk_bot.common.models.User
import space.mori.chzzk_bot.common.services.SongConfigService
import space.mori.chzzk_bot.common.services.SongListService import space.mori.chzzk_bot.common.services.SongListService
import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.utils.CurrentSong import space.mori.chzzk_bot.webserver.utils.CurrentSong
import java.io.IOException
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongListRoutes() { fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes") val logger = LoggerFactory.getLogger("WSSongListRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val sessionMutex = Mutex() val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
// Manage all active sessions
val sessionHandlers = ConcurrentHashMap<String, SessionHandler>()
// Handle application shutdown
environment.monitor.subscribe(ApplicationStopped) { environment.monitor.subscribe(ApplicationStopped) {
songListScope.cancel() sessionHandlers.values.forEach {
}
val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>()
suspend fun addSession(uid: String, session: WebSocketServerSession) {
val oldSession = sessionMutex.withLock {
val old = sessions[uid]
sessions[uid] = session
old
}
if(oldSession != null) {
songListScope.launch { songListScope.launch {
try { it.close(CloseReason(CloseReason.Codes.NORMAL, "Server shutting down"))
oldSession.close(CloseReason(
CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
} catch(e: Exception) {
logger.warn("Error closing old session: ${e.message}")
}
} }
} }
} }
suspend fun removeSession(uid: String) { // WebSocket endpoint
sessionMutex.withLock {
sessions.remove(uid)
}
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedUid: String): Boolean {
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[expectedUid] = ackDeferred
return try {
withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false
} catch (e: CancellationException) {
false
} finally {
ackMap.remove(expectedUid)
}
}
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] }
if (ws == null) {
logger.debug("No active session for $uid. Retrying in $delayMillis ms.")
delay(delayMillis)
attempt++
continue
}
try {
ws.sendSerialized(res)
val ackReceived = waitForAck(ws, res.uid)
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
sentSuccessfully = true
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
attempt++
}
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException || e is IOException) {
removeSession(uid)
}
}
if (!sentSuccessfully && attempt < maxRetries) {
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
}
webSocket("/songlist") { webSocket("/songlist") {
val session = call.sessions.get<UserSession>() val session = call.sessions.get<UserSession>()
val user = session?.id?.let { UserService.getUser(it) } val user: User? = session?.id?.let { UserService.getUser(it) }
if (user == null) { if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID"))
return@webSocket return@webSocket
} }
val uid = user.token val uid = user.token
addSession(uid, this)
if (status[uid] == SongType.STREAM_OFF) { // Ensure only one session per user
songListScope.launch { sessionHandlers[uid]?.close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
try {
sendSerialized(SongResponse( val handler = SessionHandler(uid, this, dispatcher, logger)
SongType.STREAM_OFF.value, sessionHandlers[uid] = handler
uid,
null, // Initialize session
null, handler.initialize()
null,
)) // Listen for incoming frames
} catch (e: Exception) {
logger.warn("Error sending STREAM_OFF: ${e.message}")
} finally {
removeSession(uid)
}
}
return@webSocket
}
try { try {
for (frame in incoming) { for (frame in incoming) {
when (frame) { when (frame) {
is Frame.Text -> { is Frame.Text -> handler.handleTextFrame(frame.readText())
val text = frame.readText()
if (text.trim() == "ping") {
send("pong")
} else {
val data = Json.decodeFromString<SongRequest>(text)
if (data.type == SongType.ACK.value) {
ackMap[data.uid]?.complete(true)
} else {
handleSongRequest(data, user, dispatcher, logger)
}
}
}
is Frame.Ping -> send(Frame.Pong(frame.data)) is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {} else -> Unit
} }
} }
} catch (e: ClosedReceiveChannelException) { } catch (e: ClosedReceiveChannelException) {
logger.error("WebSocket connection closed: ${e.message}") logger.info("Session closed: ${e.message}")
} catch (e: IOException) {
logger.error("IO error: ${e.message}")
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error in WebSocket: ${e.message}") logger.error("Unexpected error: ${e.message}")
} finally { } finally {
removeSession(uid) sessionHandlers.remove(uid)
ackMap.remove(uid) handler.close(CloseReason(CloseReason.Codes.NORMAL, "Session ended"))
} }
} }
dispatcher.subscribe(SongEvent::class) { // Subscribe to SongEvents
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) dispatcher.subscribe(SongEvent::class) { event ->
val handler = sessionHandlers[event.uid]
songListScope.launch { songListScope.launch {
try { handler?.sendSongResponse(event)
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
it.reqUid,
it.current?.toSerializable(),
it.next?.toSerializable(),
it.delUrl
)
)
}
} catch(e: Exception) {
logger.error("Error handling song event: ${e.message}")
} }
} }
}
dispatcher.subscribe(TimerEvent::class) { // Subscribe to TimerEvents
if (it.type == TimerType.STREAM_OFF) { dispatcher.subscribe(TimerEvent::class) { event ->
if (event.type == TimerType.STREAM_OFF) {
val handler = sessionHandlers[event.uid]
songListScope.launch { songListScope.launch {
try { handler?.sendTimerOff()
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
null,
null,
null,
)
)
}
} catch(e: Exception) {
logger.error("Error handling timer event: ${e.message}")
}
} }
} }
} }
} }
// 노래 처리를 위한 Mutex 추가
private val songMutex = Mutex() class SessionHandler(
fun handleSongRequest( private val uid: String,
private val session: WebSocketServerSession,
private val dispatcher: CoroutinesEventBus,
private val logger: Logger
) {
private val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>()
private val sessionMutex = Mutex()
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
suspend fun initialize() {
// Send initial status if needed,
// For example, send STREAM_OFF if applicable
// This can be extended based on your requirements
}
suspend fun handleTextFrame(text: String) {
if (text.trim() == "ping") {
session.send("pong")
return
}
val data = try {
Json.decodeFromString<SongRequest>(text)
} catch (e: Exception) {
logger.warn("Failed to decode SongRequest: ${e.message}")
return
}
when (data.type) {
SongType.ACK.value -> handleAck(data.uid)
else -> handleSongRequest(data)
}
}
private fun handleAck(requestUid: String) {
ackMap[requestUid]?.complete(true)
ackMap.remove(requestUid)
}
private fun handleSongRequest(data: SongRequest) {
scope.launch {
SongRequestProcessor.process(data, uid, dispatcher, this@SessionHandler, logger)
}
}
suspend fun sendSongResponse(event: SongEvent) {
val response = SongResponse(
type = event.type.value,
uid = event.uid,
reqUid = event.reqUid,
current = event.current?.toSerializable(),
next = event.next?.toSerializable(),
delUrl = event.delUrl
)
sendWithRetry(response)
}
suspend fun sendTimerOff() {
val response = SongResponse(
type = TimerType.STREAM_OFF.value,
uid = uid,
reqUid = null,
current = null,
next = null,
delUrl = null
)
sendWithRetry(response)
}
private suspend fun sendWithRetry(res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(res)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[res.uid] = ackDeferred
val ackReceived = withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
return
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
}
} catch (e: IOException) {
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException) {
close(CloseReason(CloseReason.Codes.PROTOCOL_ERROR, "WebSocket error"))
return
}
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.warn("Unexpected error while sending message to $uid on attempt $attempt: ${e.message}")
}
attempt++
delay(delayMillis)
}
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
suspend fun close(reason: CloseReason) {
try {
session.close(reason)
} catch (e: Exception) {
logger.warn("Error closing session: ${e.message}")
}
}
}
object SongRequestProcessor {
private val songMutex = Mutex()
suspend fun process(
data: SongRequest,
uid: String,
dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger
) {
val user = UserService.getUser(uid) ?: return
when (data.type) {
SongType.ADD.value -> handleAdd(data, user, dispatcher, handler, logger)
SongType.REMOVE.value -> handleRemove(data, user, dispatcher, logger)
SongType.NEXT.value -> handleNext(user, dispatcher, logger)
else -> {
// Handle other types if necessary
}
}
}
private suspend fun handleAdd(
data: SongRequest, data: SongRequest,
user: User, user: User,
dispatcher: CoroutinesEventBus, dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger logger: Logger
) { ) {
if (data.maxQueue != null && data.maxQueue > 0) SongConfigService.updateQueueLimit(user, data.maxQueue) val url = data.url ?: return
if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit) val youtubeVideo = getYoutubeVideo(url) ?: run {
if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly) logger.warn("Failed to fetch YouTube video for URL: $url")
if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled) return
when (data.type) { }
SongType.ADD.value -> {
data.url?.let { url ->
try {
val youtubeVideo = getYoutubeVideo(url)
if (youtubeVideo != null) {
songListScope.launch {
songMutex.withLock { songMutex.withLock {
SongListService.saveSong( SongListService.saveSong(
user, user,
@ -257,87 +263,95 @@ fun handleSongRequest(
youtubeVideo.length, youtubeVideo.length,
user.username user.username
) )
}
dispatcher.post( dispatcher.post(
SongEvent( SongEvent(
user.token, uid = user.token,
SongType.ADD, type = SongType.ADD,
user.token, reqUid = user.token,
CurrentSong.getSong(user), current = CurrentSong.getSong(user),
youtubeVideo next = youtubeVideo
) )
) )
} }
}
} private suspend fun handleRemove(
} catch (e: Exception) { data: SongRequest,
logger.debug("SongType.ADD Error: ${user.token} $e") user: User,
} dispatcher: CoroutinesEventBus,
} logger: Logger
} ) {
SongType.REMOVE.value -> { val url = data.url ?: return
data.url?.let { url ->
songListScope.launch {
songMutex.withLock { songMutex.withLock {
val songs = SongListService.getSong(user) val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url } val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) { if (exactSong != null) {
SongListService.deleteSong(user, exactSong.uid, exactSong.name) SongListService.deleteSong(user, exactSong.uid, exactSong.name)
} }
}
dispatcher.post( dispatcher.post(
SongEvent( SongEvent(
user.token, uid = user.token,
SongType.REMOVE, type = SongType.REMOVE,
null, delUrl = url,
null, reqUid = null,
null, current = null,
url next = null,
) )
) )
} }
}
} private suspend fun handleNext(
} user: User,
SongType.NEXT.value -> { dispatcher: CoroutinesEventBus,
songListScope.launch { logger: Logger
songMutex.withLock { ) {
val songList = SongListService.getSong(user)
var song: SongList? = null var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null var youtubeVideo: YoutubeVideo? = null
songMutex.withLock {
val songList = SongListService.getSong(user)
if (songList.isNotEmpty()) { if (songList.isNotEmpty()) {
song = songList[0] song = songList[0]
SongListService.deleteSong(user, song.uid, song.name) SongListService.deleteSong(user, song.uid, song.name)
} }
}
song?.let { song?.let {
youtubeVideo = YoutubeVideo( youtubeVideo = YoutubeVideo(
song.url, it.url,
song.name, it.name,
song.author, it.author,
song.time it.time
) )
} }
dispatcher.post( dispatcher.post(
SongEvent( SongEvent(
user.token, uid = user.token,
SongType.NEXT, type = SongType.NEXT,
song?.uid, current = null,
youtubeVideo next = youtubeVideo,
reqUid = null,
delUrl = null
) )
) )
CurrentSong.setSong(user, youtubeVideo) CurrentSong.setSong(user, youtubeVideo)
} }
}
}
}
} }
@Serializable @Serializable
data class SongRequest( data class SongRequest(
val type: Int, val type: Int,
val uid: String, val uid: String,
val url: String?, val url: String? = null,
val maxQueue: Int?, val maxQueue: Int? = null,
val maxUserLimit: Int?, val maxUserLimit: Int? = null,
val isStreamerOnly: Boolean?, val isStreamerOnly: Boolean? = null,
val remove: Int?, val remove: Int? = null,
val isDisabled: Boolean?, val isDisabled: Boolean? = null
) )