Refactor WebSocket route to use shared CoroutineScope

Introduced a shared `routeScope` with `SupervisorJob` for better coroutine management across WebSocket routes. This replaces ad-hoc CoroutineScope creation, preventing unnecessary scope overhead and supporting centralized cancellation. Mutexes were added for session and song-related operations to ensure thread safety.
This commit is contained in:
dalbodeule 2025-04-24 16:23:55 +09:00
parent 7a84a9e437
commit 5a7f78ff3e
No known key found for this signature in database
GPG Key ID: EFA860D069C9FA65
2 changed files with 285 additions and 91 deletions

View File

@ -1,19 +1,29 @@
import io.ktor.client.plugins.websocket.WebSocketException
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.sessions.* import io.ktor.server.sessions.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger import io.ktor.util.logging.Logger
import io.ktor.utils.io.CancellationException
import io.ktor.websocket.* import io.ktor.websocket.*
import io.ktor.server.application.*
import io.ktor.websocket.Frame.* import io.ktor.websocket.Frame.*
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.IOException
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.SongList
import space.mori.chzzk_bot.common.models.SongLists.uid
import space.mori.chzzk_bot.common.models.User import space.mori.chzzk_bot.common.models.User
import space.mori.chzzk_bot.common.services.SongConfigService import space.mori.chzzk_bot.common.services.SongConfigService
import space.mori.chzzk_bot.common.services.SongListService import space.mori.chzzk_bot.common.services.SongListService
@ -21,56 +31,228 @@ import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.routes.SongResponse
import space.mori.chzzk_bot.webserver.routes.toSerializable
import space.mori.chzzk_bot.webserver.utils.CurrentSong import space.mori.chzzk_bot.webserver.utils.CurrentSong
import space.mori.chzzk_bot.webserver.utils.SongListWebSocketManager import java.util.concurrent.ConcurrentHashMap
fun Route.wsSongListRoutes(songListWebSocketManager: SongListWebSocketManager) { val routeScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
// 세션 관련 작업을 위한 Mutex 추가
val sessionMutex = Mutex()
environment.monitor.subscribe(ApplicationStopped) {
routeScope.cancel()
}
suspend fun addSession(uid: String, session: WebSocketServerSession) {
val oldSession = sessionMutex.withLock {
val old = sessions[uid]
sessions[uid] = session
old
}
if(oldSession != null) {
routeScope.launch {
try {
oldSession.close(CloseReason(
CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
} catch(e: Exception) {
logger.warn("Error closing old session: ${e.message}")
e.printStackTrace()
}
}
}
}
suspend fun removeSession(uid: String) {
sessionMutex.withLock {
sessions.remove(uid)
}
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedType: Int): Boolean {
return withTimeoutOrNull(5000L) { // 5초 타임아웃
try {
for (frame in ws.incoming) {
if (frame is Text) {
val message = frame.readText()
if(message == "ping") {
return@withTimeoutOrNull true
}
val data = Json.decodeFromString<SongRequest>(message)
if (data.type == SongType.ACK.value) {
return@withTimeoutOrNull true // ACK 받음
}
}
}
false // 채널이 닫힘
} catch (e: Exception) {
logger.warn("Error waiting for ACK: ${e.message}")
false
}
} ?: false // 타임아웃 시 false 반환
}
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] } ?: run {
logger.debug("No active session for $uid. Retrying in $delayMillis ms.")
delay(delayMillis)
attempt++
null
}
if(ws == null) continue
try {
// 메시지 전송 시도
ws.sendSerialized(res)
logger.debug("Message sent successfully to $uid on attempt $attempt")
// ACK 대기
val ackReceived = waitForAck(ws, res.type)
if (ackReceived) {
sentSuccessfully = true
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
attempt++
}
} catch (e: CancellationException) {
// 코루틴 취소는 다시 throw
throw e
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException || e is IOException) {
logger.warn("Connection issue detected, session may be invalid")
// 연결 문제로 보이면 세션을 제거할 수도 있음
removeSession(uid)
}
}
if (!sentSuccessfully && attempt < maxRetries) {
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
}
webSocket("/songlist") { webSocket("/songlist") {
val session = call.sessions.get<UserSession>() val session = call.sessions.get<UserSession>()
val uid = session?.id val user = session?.id?.let { UserService.getUser(it) }
if (uid == null) { if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid session")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID"))
return@webSocket return@webSocket
} }
val uid = user.token
addSession(uid, this)
if (status[uid] == SongType.STREAM_OFF) {
routeScope.launch {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
null,
null,
null,
))
}
removeSession(uid)
}
try { try {
songListWebSocketManager.addSession(uid, this)
for (frame in incoming) { for (frame in incoming) {
when (frame) { when (frame) {
is Text -> { is Text -> {
val text = frame.readText().trim() val text = frame.readText()
if (text == SongListWebSocketManager.PING_MESSAGE) { if (text.trim() == "ping") {
send(SongListWebSocketManager.PONG_MESSAGE) send("pong")
songListWebSocketManager.handlePong(uid)
} else { } else {
val data = Json.decodeFromString<SongRequest>(text)
// Handle song requests // Handle song requests
text.let { Json.decodeFromString<SongRequest>(it) }.let { data -> handleSongRequest(data, user, dispatcher, logger)
val user = session.id.let { UserService.getUser(it) }
if(user == null) {
songListWebSocketManager.removeSession(uid)
return@webSocket
}
handleSongRequest(data, user, dispatcher, songListWebSocketManager.logger)
}.runCatching { songListWebSocketManager.logger.error("Failed to parse WebSocket message as SongRequest.") }
} }
} }
is Ping -> send(Pong(frame.data)) is Ping -> send(Pong(frame.data))
else -> songListWebSocketManager.logger.warn("Unsupported frame type received.") else -> ""
} }
} }
} catch (e: ClosedReceiveChannelException) {
logger.error("WebSocket connection closed: ${e.message}")
} catch(e: Exception) {
logger.error("Error in WebSocket: ${e.message}")
} finally {
removeSession(uid)
}
}
dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
routeScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
it.reqUid,
it.current?.toSerializable(),
it.next?.toSerializable(),
it.delUrl
)
)
}
} catch(e: Exception) { } catch(e: Exception) {
songListWebSocketManager.logger.error("WebSocket error: ${e.message}") logger.error("Error handling song event: ${e.message}")
} finally {
songListWebSocketManager.removeSession(uid)
} }
} }
} }
dispatcher.subscribe(TimerEvent::class) {
if (it.type == TimerType.STREAM_OFF) {
routeScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
null,
null,
null,
)
)
}
} catch(e: Exception) {
logger.error("Error handling timer event: ${e.message}")
}
}
}
}
}
// 노래 처리를 위한 Mutex 추가
private val songMutex = Mutex()
suspend fun handleSongRequest( suspend fun handleSongRequest(
data: SongRequest, data: SongRequest,
@ -89,7 +271,8 @@ suspend fun handleSongRequest(
try { try {
val youtubeVideo = getYoutubeVideo(url) val youtubeVideo = getYoutubeVideo(url)
if (youtubeVideo != null) { if (youtubeVideo != null) {
CoroutineScope(Dispatchers.Default).launch { routeScope.launch {
songMutex.withLock {
SongListService.saveSong( SongListService.saveSong(
user, user,
user.token, user.token,
@ -110,13 +293,16 @@ suspend fun handleSongRequest(
) )
} }
} }
}
} catch (e: Exception) { } catch (e: Exception) {
logger.debug("SongType.ADD Error: $uid $e") logger.debug("SongType.ADD Error: ${user.token} $e")
} }
} }
} }
SongType.REMOVE.value -> { SongType.REMOVE.value -> {
data.url?.let { url -> data.url?.let { url ->
routeScope.launch {
songMutex.withLock {
val songs = SongListService.getSong(user) val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url } val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) { if (exactSong != null) {
@ -134,7 +320,11 @@ suspend fun handleSongRequest(
) )
} }
} }
}
}
SongType.NEXT.value -> { SongType.NEXT.value -> {
routeScope.launch {
songMutex.withLock {
val songList = SongListService.getSong(user) val songList = SongListService.getSong(user)
var song: SongList? = null var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null var youtubeVideo: YoutubeVideo? = null
@ -165,7 +355,8 @@ suspend fun handleSongRequest(
} }
} }
} }
}
}
@Serializable @Serializable
data class SongRequest( data class SongRequest(

View File

@ -5,6 +5,7 @@ import io.ktor.server.websocket.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@ -17,6 +18,8 @@ import space.mori.chzzk_bot.common.utils.YoutubeVideo
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.ConcurrentLinkedQueue
val routeScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongRoutes() { fun Routing.wsSongRoutes() {
val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>() val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>()
val status = ConcurrentHashMap<String, SongType>() val status = ConcurrentHashMap<String, SongType>()
@ -58,7 +61,7 @@ fun Routing.wsSongRoutes() {
val userSessions = sessions[userId] val userSessions = sessions[userId]
userSessions?.forEach { session -> userSessions?.forEach { session ->
CoroutineScope(Dispatchers.Default).launch { routeScope.launch {
val success = sendWithRetry(session, message) val success = sendWithRetry(session, message)
if (!success) { if (!success) {
println("Removing session for user $userId due to repeated failures.") println("Removing session for user $userId due to repeated failures.")
@ -83,7 +86,7 @@ fun Routing.wsSongRoutes() {
addSession(uid, this) addSession(uid, this)
if(status[uid] == SongType.STREAM_OFF) { if(status[uid] == SongType.STREAM_OFF) {
CoroutineScope(Dispatchers.Default).launch { routeScope.launch {
sendSerialized(SongResponse( sendSerialized(SongResponse(
SongType.STREAM_OFF.value, SongType.STREAM_OFF.value,
uid, uid,
@ -119,7 +122,7 @@ fun Routing.wsSongRoutes() {
dispatcher.subscribe(SongEvent::class) { dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
CoroutineScope(Dispatchers.Default).launch { routeScope.launch {
broadcastMessage(it.uid, SongResponse( broadcastMessage(it.uid, SongResponse(
it.type.value, it.type.value,
it.uid, it.uid,
@ -132,7 +135,7 @@ fun Routing.wsSongRoutes() {
} }
dispatcher.subscribe(TimerEvent::class) { dispatcher.subscribe(TimerEvent::class) {
if(it.type == TimerType.STREAM_OFF) { if(it.type == TimerType.STREAM_OFF) {
CoroutineScope(Dispatchers.Default).launch { routeScope.launch {
broadcastMessage(it.uid, SongResponse( broadcastMessage(it.uid, SongResponse(
it.type.value, it.type.value,
it.uid, it.uid,