Refactor WebSocket song list handling and improve session logic

Replaced individual WebSocket session management with `SessionHandler` to centralize and streamline logic. Improved code readability, reliability, and maintainability by reducing redundancy and encapsulating session and request handling in dedicated classes. Added retry mechanisms, acknowledgment handling, and better application shutdown handling.
This commit is contained in:
dalbodeule 2025-04-24 17:45:25 +09:00
parent c5a98943c0
commit aa95976005
No known key found for this signature in database
GPG Key ID: EFA860D069C9FA65

View File

@ -1,332 +1,346 @@
package space.mori.chzzk_bot.webserver.routes
import io.ktor.client.plugins.websocket.WebSocketException
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.application.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger
import io.ktor.utils.io.CancellationException
import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.IOException
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.*
import space.mori.chzzk_bot.common.models.SongList
import space.mori.chzzk_bot.common.models.User
import space.mori.chzzk_bot.common.services.SongConfigService
import space.mori.chzzk_bot.common.services.SongListService
import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.utils.CurrentSong
import java.io.IOException
import java.util.concurrent.ConcurrentHashMap
val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val sessionMutex = Mutex()
val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
// Manage all active sessions
val sessionHandlers = ConcurrentHashMap<String, SessionHandler>()
// Handle application shutdown
environment.monitor.subscribe(ApplicationStopped) {
songListScope.cancel()
}
val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>()
suspend fun addSession(uid: String, session: WebSocketServerSession) {
val oldSession = sessionMutex.withLock {
val old = sessions[uid]
sessions[uid] = session
old
}
if(oldSession != null) {
sessionHandlers.values.forEach {
songListScope.launch {
try {
oldSession.close(CloseReason(
CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
} catch(e: Exception) {
logger.warn("Error closing old session: ${e.message}")
}
it.close(CloseReason(CloseReason.Codes.NORMAL, "Server shutting down"))
}
}
}
suspend fun removeSession(uid: String) {
sessionMutex.withLock {
sessions.remove(uid)
}
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedUid: String): Boolean {
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[expectedUid] = ackDeferred
return try {
withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false
} catch (e: CancellationException) {
false
} finally {
ackMap.remove(expectedUid)
}
}
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] }
if (ws == null) {
logger.debug("No active session for $uid. Retrying in $delayMillis ms.")
delay(delayMillis)
attempt++
continue
}
try {
ws.sendSerialized(res)
val ackReceived = waitForAck(ws, res.uid)
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
sentSuccessfully = true
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
attempt++
}
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException || e is IOException) {
removeSession(uid)
}
}
if (!sentSuccessfully && attempt < maxRetries) {
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
}
// WebSocket endpoint
webSocket("/songlist") {
val session = call.sessions.get<UserSession>()
val user = session?.id?.let { UserService.getUser(it) }
val user: User? = session?.id?.let { UserService.getUser(it) }
if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID"))
return@webSocket
}
val uid = user.token
addSession(uid, this)
if (status[uid] == SongType.STREAM_OFF) {
songListScope.launch {
try {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
null,
null,
null,
))
} catch (e: Exception) {
logger.warn("Error sending STREAM_OFF: ${e.message}")
} finally {
removeSession(uid)
}
}
return@webSocket
}
// Ensure only one session per user
sessionHandlers[uid]?.close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
val handler = SessionHandler(uid, this, dispatcher, logger)
sessionHandlers[uid] = handler
// Initialize session
handler.initialize()
// Listen for incoming frames
try {
for (frame in incoming) {
when (frame) {
is Frame.Text -> {
val text = frame.readText()
if (text.trim() == "ping") {
send("pong")
} else {
val data = Json.decodeFromString<SongRequest>(text)
if (data.type == SongType.ACK.value) {
ackMap[data.uid]?.complete(true)
} else {
handleSongRequest(data, user, dispatcher, logger)
}
}
}
is Frame.Text -> handler.handleTextFrame(frame.readText())
is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {}
else -> Unit
}
}
} catch (e: ClosedReceiveChannelException) {
logger.error("WebSocket connection closed: ${e.message}")
logger.info("Session closed: ${e.message}")
} catch (e: IOException) {
logger.error("IO error: ${e.message}")
} catch (e: Exception) {
logger.error("Error in WebSocket: ${e.message}")
logger.error("Unexpected error: ${e.message}")
} finally {
removeSession(uid)
ackMap.remove(uid)
sessionHandlers.remove(uid)
handler.close(CloseReason(CloseReason.Codes.NORMAL, "Session ended"))
}
}
dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
// Subscribe to SongEvents
dispatcher.subscribe(SongEvent::class) { event ->
val handler = sessionHandlers[event.uid]
songListScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
it.reqUid,
it.current?.toSerializable(),
it.next?.toSerializable(),
it.delUrl
)
)
}
} catch(e: Exception) {
logger.error("Error handling song event: ${e.message}")
}
handler?.sendSongResponse(event)
}
}
dispatcher.subscribe(TimerEvent::class) {
if (it.type == TimerType.STREAM_OFF) {
// Subscribe to TimerEvents
dispatcher.subscribe(TimerEvent::class) { event ->
if (event.type == TimerType.STREAM_OFF) {
val handler = sessionHandlers[event.uid]
songListScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
null,
null,
null,
)
)
}
} catch(e: Exception) {
logger.error("Error handling timer event: ${e.message}")
}
handler?.sendTimerOff()
}
}
}
}
// 노래 처리를 위한 Mutex 추가
private val songMutex = Mutex()
fun handleSongRequest(
data: SongRequest,
user: User,
dispatcher: CoroutinesEventBus,
logger: Logger
class SessionHandler(
private val uid: String,
private val session: WebSocketServerSession,
private val dispatcher: CoroutinesEventBus,
private val logger: Logger
) {
if (data.maxQueue != null && data.maxQueue > 0) SongConfigService.updateQueueLimit(user, data.maxQueue)
if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit)
if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly)
if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled)
when (data.type) {
SongType.ADD.value -> {
data.url?.let { url ->
try {
val youtubeVideo = getYoutubeVideo(url)
if (youtubeVideo != null) {
songListScope.launch {
songMutex.withLock {
SongListService.saveSong(
user,
user.token,
url,
youtubeVideo.name,
youtubeVideo.author,
youtubeVideo.length,
user.username
)
dispatcher.post(
SongEvent(
user.token,
SongType.ADD,
user.token,
CurrentSong.getSong(user),
youtubeVideo
)
)
}
}
}
} catch (e: Exception) {
logger.debug("SongType.ADD Error: ${user.token} $e")
private val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>()
private val sessionMutex = Mutex()
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
suspend fun initialize() {
// Send initial status if needed,
// For example, send STREAM_OFF if applicable
// This can be extended based on your requirements
}
suspend fun handleTextFrame(text: String) {
if (text.trim() == "ping") {
session.send("pong")
return
}
val data = try {
Json.decodeFromString<SongRequest>(text)
} catch (e: Exception) {
logger.warn("Failed to decode SongRequest: ${e.message}")
return
}
when (data.type) {
SongType.ACK.value -> handleAck(data.uid)
else -> handleSongRequest(data)
}
}
private fun handleAck(requestUid: String) {
ackMap[requestUid]?.complete(true)
ackMap.remove(requestUid)
}
private fun handleSongRequest(data: SongRequest) {
scope.launch {
SongRequestProcessor.process(data, uid, dispatcher, this@SessionHandler, logger)
}
}
suspend fun sendSongResponse(event: SongEvent) {
val response = SongResponse(
type = event.type.value,
uid = event.uid,
reqUid = event.reqUid,
current = event.current?.toSerializable(),
next = event.next?.toSerializable(),
delUrl = event.delUrl
)
sendWithRetry(response)
}
suspend fun sendTimerOff() {
val response = SongResponse(
type = TimerType.STREAM_OFF.value,
uid = uid,
reqUid = null,
current = null,
next = null,
delUrl = null
)
sendWithRetry(response)
}
private suspend fun sendWithRetry(res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(res)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[res.uid] = ackDeferred
val ackReceived = withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
return
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
}
} catch (e: IOException) {
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException) {
close(CloseReason(CloseReason.Codes.PROTOCOL_ERROR, "WebSocket error"))
return
}
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
logger.warn("Unexpected error while sending message to $uid on attempt $attempt: ${e.message}")
}
attempt++
delay(delayMillis)
}
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
suspend fun close(reason: CloseReason) {
try {
session.close(reason)
} catch (e: Exception) {
logger.warn("Error closing session: ${e.message}")
}
}
}
object SongRequestProcessor {
private val songMutex = Mutex()
suspend fun process(
data: SongRequest,
uid: String,
dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger
) {
val user = UserService.getUser(uid) ?: return
when (data.type) {
SongType.ADD.value -> handleAdd(data, user, dispatcher, handler, logger)
SongType.REMOVE.value -> handleRemove(data, user, dispatcher, logger)
SongType.NEXT.value -> handleNext(user, dispatcher, logger)
else -> {
// Handle other types if necessary
}
}
SongType.REMOVE.value -> {
data.url?.let { url ->
songListScope.launch {
songMutex.withLock {
val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) {
SongListService.deleteSong(user, exactSong.uid, exactSong.name)
}
dispatcher.post(
SongEvent(
user.token,
SongType.REMOVE,
null,
null,
null,
url
)
)
}
}
}
private suspend fun handleAdd(
data: SongRequest,
user: User,
dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger
) {
val url = data.url ?: return
val youtubeVideo = getYoutubeVideo(url) ?: run {
logger.warn("Failed to fetch YouTube video for URL: $url")
return
}
songMutex.withLock {
SongListService.saveSong(
user,
user.token,
url,
youtubeVideo.name,
youtubeVideo.author,
youtubeVideo.length,
user.username
)
}
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.ADD,
reqUid = user.token,
current = CurrentSong.getSong(user),
next = youtubeVideo
)
)
}
private suspend fun handleRemove(
data: SongRequest,
user: User,
dispatcher: CoroutinesEventBus,
logger: Logger
) {
val url = data.url ?: return
songMutex.withLock {
val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) {
SongListService.deleteSong(user, exactSong.uid, exactSong.name)
}
}
SongType.NEXT.value -> {
songListScope.launch {
songMutex.withLock {
val songList = SongListService.getSong(user)
var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null
if (songList.isNotEmpty()) {
song = songList[0]
SongListService.deleteSong(user, song.uid, song.name)
}
song?.let {
youtubeVideo = YoutubeVideo(
song.url,
song.name,
song.author,
song.time
)
}
dispatcher.post(
SongEvent(
user.token,
SongType.NEXT,
song?.uid,
youtubeVideo
)
)
CurrentSong.setSong(user, youtubeVideo)
}
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.REMOVE,
delUrl = url,
reqUid = null,
current = null,
next = null,
)
)
}
private suspend fun handleNext(
user: User,
dispatcher: CoroutinesEventBus,
logger: Logger
) {
var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null
songMutex.withLock {
val songList = SongListService.getSong(user)
if (songList.isNotEmpty()) {
song = songList[0]
SongListService.deleteSong(user, song.uid, song.name)
}
}
song?.let {
youtubeVideo = YoutubeVideo(
it.url,
it.name,
it.author,
it.time
)
}
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.NEXT,
current = null,
next = youtubeVideo,
reqUid = null,
delUrl = null
)
)
CurrentSong.setSong(user, youtubeVideo)
}
}
@ -334,10 +348,10 @@ fun handleSongRequest(
data class SongRequest(
val type: Int,
val uid: String,
val url: String?,
val maxQueue: Int?,
val maxUserLimit: Int?,
val isStreamerOnly: Boolean?,
val remove: Int?,
val isDisabled: Boolean?,
)
val url: String? = null,
val maxQueue: Int? = null,
val maxUserLimit: Int? = null,
val isStreamerOnly: Boolean? = null,
val remove: Int? = null,
val isDisabled: Boolean? = null
)