Refactor WebSocket handlers and add ACK-based message flow

Consolidated coroutine scopes into `songListScope` and `timerScope` for better management across WebSocket routes. Introduced ACK (acknowledgment) handling for reliable message delivery with retries and timeouts. Updated session handling for multiple WebSocket routes to improve code maintainability and consistency.
This commit is contained in:
dalbodeule 2025-04-24 16:56:49 +09:00
parent d07cdb6ae8
commit 8230762053
No known key found for this signature in database
GPG Key ID: EFA860D069C9FA65
5 changed files with 190 additions and 161 deletions

View File

@ -5,7 +5,8 @@ enum class TimerType(var value: Int) {
TIMER(1),
REMOVE(2),
STREAM_OFF(50)
STREAM_OFF(50),
ACK(51)
}
class TimerEvent(

View File

@ -22,11 +22,9 @@ import io.ktor.server.websocket.*
import kotlinx.coroutines.delay
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.webserver.routes.*
import space.mori.chzzk_bot.webserver.utils.DiscordRatelimits
import wsSongListRoutes
import java.math.BigInteger
import java.security.SecureRandom
import java.time.Duration

View File

@ -1,3 +1,5 @@
package space.mori.chzzk_bot.webserver.routes
import io.ktor.client.plugins.websocket.WebSocketException
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.*
@ -6,7 +8,7 @@ import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger
import io.ktor.utils.io.CancellationException
import io.ktor.websocket.*
import io.ktor.websocket.Frame.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
@ -14,8 +16,10 @@ import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.IOException
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
@ -30,35 +34,29 @@ import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.routes.SongResponse
import space.mori.chzzk_bot.webserver.routes.toSerializable
import space.mori.chzzk_bot.webserver.utils.CurrentSong
import java.util.concurrent.ConcurrentHashMap
val routeScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
// 세션 관련 작업을 위한 Mutex 추가
val sessionMutex = Mutex()
environment.monitor.subscribe(ApplicationStopped) {
routeScope.cancel()
songListScope.cancel()
}
suspend fun addSession(uid: String, session: WebSocketServerSession) {
val oldSession = sessionMutex.withLock {
val old = sessions[uid]
sessions[uid] = session
old
}
if(oldSession != null) {
routeScope.launch {
songListScope.launch {
try {
oldSession.close(CloseReason(
CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
@ -69,87 +67,54 @@ fun Routing.wsSongListRoutes() {
}
}
}
suspend fun removeSession(uid: String) {
sessionMutex.withLock {
sessions.remove(uid)
}
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedType: Int): Boolean {
try {
for (frame in ws.incoming) {
if (frame is Text) {
val message = frame.readText()
if (message == "ping") {
continue // Keep the loop running if a ping is received
}
val data = Json.decodeFromString<SongRequest>(message)
if (data.type == SongType.ACK.value) {
return true // ACK received
}
}
}
} catch (e: Exception) {
logger.warn("Error waiting for ACK: ${e.message}")
}
return false // Return false if no ACK received
}
val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>()
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] } ?: run {
while (attempt < maxRetries) {
val ws: WebSocketServerSession? = sessionMutex.withLock { sessions[uid] }
if (ws == null) {
logger.debug("No active session for $uid. Retrying in $delayMillis ms.")
delay(delayMillis)
attempt++
null
continue
}
if (ws == null) continue
try {
// 메시지 전송 시도
ws.sendSerialized(res)
logger.debug("Message sent successfully to $uid on attempt $attempt")
// ACK 대기
val ackReceived = waitForAck(ws, res.type)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[res.uid] = ackDeferred
val ackReceived = withTimeoutOrNull(delayMillis) { ackDeferred.await() } ?: false
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
sentSuccessfully = true
return
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
attempt++
}
} catch (e: CancellationException) {
// 코루틴 취소는 다시 throw
throw e
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}")
if (e is WebSocketException || e is IOException) {
logger.warn("Connection issue detected, session may be invalid")
// 연결 문제로 보이면 세션을 제거할 수도 있음
removeSession(uid)
}
}
if (!sentSuccessfully && attempt < maxRetries) {
if (attempt < maxRetries) {
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
webSocket("/songlist") {
val session = call.sessions.get<UserSession>()
val user = session?.id?.let { UserService.getUser(it) }
@ -157,13 +122,10 @@ fun Routing.wsSongListRoutes() {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID"))
return@webSocket
}
val uid = user.token
addSession(uid, this)
if (status[uid] == SongType.STREAM_OFF) {
routeScope.launch {
songListScope.launch {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
@ -175,23 +137,32 @@ fun Routing.wsSongListRoutes() {
removeSession(uid)
}
try {
songListScope.launch {
for (frame in incoming) {
when (frame) {
is Text -> {
is Frame.Text -> {
val text = frame.readText()
if (text.trim() == "ping") {
send("pong")
} else {
val data = Json.decodeFromString<SongRequest>(text)
// Handle song requests
handleSongRequest(data, user, dispatcher, logger)
if (data.type == SongType.ACK.value) {
ackMap[data.uid]?.complete(true)
ackMap.remove(data.uid)
} else {
handleSongRequest(data, user, dispatcher, logger)
}
}
}
is Ping -> send(Pong(frame.data))
else -> ""
is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {}
}
}
}
try {
// Keep the connection alive
suspendCancellableCoroutine<Unit> {}
} catch (e: ClosedReceiveChannelException) {
logger.error("WebSocket connection closed: ${e.message}")
} catch(e: Exception) {
@ -203,7 +174,7 @@ fun Routing.wsSongListRoutes() {
dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
routeScope.launch {
songListScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
@ -223,10 +194,9 @@ fun Routing.wsSongListRoutes() {
}
}
}
dispatcher.subscribe(TimerEvent::class) {
if (it.type == TimerType.STREAM_OFF) {
routeScope.launch {
songListScope.launch {
try {
val user = UserService.getUser(it.uid)
if (user != null) {
@ -247,10 +217,8 @@ fun Routing.wsSongListRoutes() {
}
}
}
// 노래 처리를 위한 Mutex 추가
private val songMutex = Mutex()
suspend fun handleSongRequest(
data: SongRequest,
user: User,
@ -261,14 +229,13 @@ suspend fun handleSongRequest(
if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit)
if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly)
if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled)
when (data.type) {
SongType.ADD.value -> {
data.url?.let { url ->
try {
val youtubeVideo = getYoutubeVideo(url)
if (youtubeVideo != null) {
routeScope.launch {
songListScope.launch {
songMutex.withLock {
SongListService.saveSong(
user,
@ -298,7 +265,7 @@ suspend fun handleSongRequest(
}
SongType.REMOVE.value -> {
data.url?.let { url ->
routeScope.launch {
songListScope.launch {
songMutex.withLock {
val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url }
@ -320,17 +287,15 @@ suspend fun handleSongRequest(
}
}
SongType.NEXT.value -> {
routeScope.launch {
songListScope.launch {
songMutex.withLock {
val songList = SongListService.getSong(user)
var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null
if (songList.isNotEmpty()) {
song = songList[0]
SongListService.deleteSong(user, song.uid, song.name)
}
song?.let {
youtubeVideo = YoutubeVideo(
song.url,
@ -347,7 +312,6 @@ suspend fun handleSongRequest(
youtubeVideo
)
)
CurrentSong.setSong(user, youtubeVideo)
}
}

View File

@ -4,6 +4,7 @@ import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.*
import io.ktor.server.websocket.*
import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
@ -11,7 +12,9 @@ import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.*
@ -20,21 +23,20 @@ import space.mori.chzzk_bot.common.utils.YoutubeVideo
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
val routeScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
val songScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongRoutes() {
environment.monitor.subscribe(ApplicationStopped) {
routeScope.cancel()
songListScope.cancel()
}
val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val ackMap = ConcurrentHashMap<String, ConcurrentHashMap<WebSocketServerSession, CompletableDeferred<Boolean>>>()
fun addSession(uid: String, session: WebSocketServerSession) {
sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session)
}
fun removeSession(uid: String, session: WebSocketServerSession) {
sessions[uid]?.remove(session)
if(sessions[uid]?.isEmpty() == true) {
@ -51,27 +53,35 @@ fun Routing.wsSongRoutes() {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(message) // 메시지 전송 시도
return true // 성공하면 true 반환
session.sendSerialized(message)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap.computeIfAbsent(message.uid) { ConcurrentHashMap() }[session] = ackDeferred
val ackReceived = withTimeoutOrNull(delayMillis) { ackDeferred.await() } ?: false
if (ackReceived) {
ackMap[message.uid]?.remove(session)
return true
} else {
attempt++
logger.warn("ACK not received for message to ${message.uid} on attempt $attempt.")
}
} catch (e: Exception) {
attempt++
logger.info("Failed to send message on attempt $attempt. Retrying in $delayMillis ms.")
e.printStackTrace()
delay(delayMillis) // 재시도 전 대기
delay(delayMillis)
}
}
return false // 재시도 실패 시 false 반환
return false
}
fun broadcastMessage(userId: String, message: SongResponse) {
val userSessions = sessions[userId]
userSessions?.forEach { session ->
routeScope.launch {
songListScope.launch {
val success = sendWithRetry(session, message)
if (!success) {
println("Removing session for user $userId due to repeated failures.")
userSessions.remove(session) // 실패 시 세션 제거
logger.info("Removing session for user $userId due to repeated failures.")
removeSession(userId, session)
}
}
}
@ -80,19 +90,13 @@ fun Routing.wsSongRoutes() {
webSocket("/song/{uid}") {
val uid = call.parameters["uid"]
val user = uid?.let { UserService.getUser(it) }
if (uid == null) {
if (uid == null || user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
addSession(uid, this)
if(status[uid] == SongType.STREAM_OFF) {
routeScope.launch {
songListScope.launch {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
@ -102,33 +106,36 @@ fun Routing.wsSongRoutes() {
))
}
}
try {
for (frame in incoming) {
when(frame) {
is Frame.Text -> {
if(frame.readText().trim() == "ping") {
val text = frame.readText().trim()
if(text == "ping") {
send("pong")
} else {
val data = Json.decodeFromString<SongRequest>(text)
if (data.type == SongType.ACK.value) {
ackMap[data.uid]?.get(this)?.complete(true)
ackMap[data.uid]?.remove(this)
}
}
}
is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {
}
else -> {}
}
}
} catch(e: ClosedReceiveChannelException) {
logger.error("Error in WebSocket: ${e.message}")
} finally {
removeSession(uid, this)
ackMap[uid]?.remove(this)
}
}
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
routeScope.launch {
songListScope.launch {
broadcastMessage(it.uid, SongResponse(
it.type.value,
it.uid,
@ -141,7 +148,7 @@ fun Routing.wsSongRoutes() {
}
dispatcher.subscribe(TimerEvent::class) {
if(it.type == TimerType.STREAM_OFF) {
routeScope.launch {
songListScope.launch {
broadcastMessage(it.uid, SongResponse(
it.type.value,
it.uid,
@ -153,7 +160,6 @@ fun Routing.wsSongRoutes() {
}
}
}
@Serializable
data class SerializableYoutubeVideo(
val url: String,
@ -161,9 +167,7 @@ data class SerializableYoutubeVideo(
val author: String,
val length: Int
)
fun YoutubeVideo.toSerializable() = SerializableYoutubeVideo(url, name, author, length)
@Serializable
data class SongResponse(
val type: Int,
@ -172,4 +176,4 @@ data class SongResponse(
val current: SerializableYoutubeVideo? = null,
val next: SerializableYoutubeVideo? = null,
val delUrl: String? = null
)
)

View File

@ -1,13 +1,20 @@
package space.mori.chzzk_bot.webserver.routes
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.*
import io.ktor.server.websocket.*
import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.*
@ -17,14 +24,19 @@ import space.mori.chzzk_bot.webserver.utils.CurrentTimer
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
val timerScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsTimerRoutes() {
environment.monitor.subscribe(ApplicationStopped) {
songListScope.cancel()
}
val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>()
val logger = LoggerFactory.getLogger("WSTimerRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val ackMap = ConcurrentHashMap<String, ConcurrentHashMap<WebSocketServerSession, CompletableDeferred<Boolean>>>()
fun addSession(uid: String, session: WebSocketServerSession) {
sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session)
}
fun removeSession(uid: String, session: WebSocketServerSession) {
sessions[uid]?.remove(session)
if(sessions[uid]?.isEmpty() == true) {
@ -32,82 +44,132 @@ fun Routing.wsTimerRoutes() {
}
}
webSocket("/timer/{uid}") {
val uid = call.parameters["uid"]
val user = uid?.let { UserService.getUser(it) }
if (uid == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
addSession(uid, this)
val timer = CurrentTimer.getTimer(user)
if(timer?.type == TimerType.STREAM_OFF) {
CoroutineScope(Dispatchers.Default).launch {
sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null))
}
} else {
CoroutineScope(Dispatchers.Default).launch {
if (timer == null) {
sendSerialized(
TimerResponse(
TimerConfigService.getConfig(user)?.option ?: TimerType.REMOVE.value,
null
)
)
suspend fun sendWithRetry(
session: WebSocketServerSession,
message: TimerResponse,
maxRetries: Int = 3,
delayMillis: Long = 2000L
): Boolean {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(message)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap.computeIfAbsent(message.uid) { ConcurrentHashMap() }[session] = ackDeferred
val ackReceived = withTimeoutOrNull(delayMillis) { ackDeferred.await() } ?: false
if (ackReceived) {
ackMap[message.uid]?.remove(session)
return true
} else {
sendSerialized(
TimerResponse(
timer.type.value,
timer.time
)
)
attempt++
logger.warn("ACK not received for message to ${message.uid} on attempt $attempt.")
}
} catch (e: Exception) {
attempt++
logger.info("Failed to send message on attempt $attempt. Retrying in $delayMillis ms.")
e.printStackTrace()
delay(delayMillis)
}
}
return false
}
fun broadcastMessage(uid: String, message: TimerResponse) {
val userSessions = sessions[uid]
userSessions?.forEach { session ->
songListScope.launch {
val success = sendWithRetry(session, message.copy(uid = uid))
if (!success) {
logger.info("Removing session for user $uid due to repeated failures.")
removeSession(uid, session)
}
}
}
}
webSocket("/timer/{uid}") {
val uid = call.parameters["uid"]
val user = uid?.let { UserService.getUser(it) }
if (uid == null || user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
addSession(uid, this)
val timer = CurrentTimer.getTimer(user)
if (timer?.type == TimerType.STREAM_OFF) {
songListScope.launch {
sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null, uid))
}
} else {
songListScope.launch {
if(timer?.type == TimerType.STREAM_OFF) {
sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null, uid))
} else {
if (timer == null) {
sendSerialized(
TimerResponse(
TimerConfigService.getConfig(user)?.option ?: TimerType.REMOVE.value,
null,
uid
)
)
} else {
sendSerialized(
TimerResponse(
timer.type.value,
timer.time,
uid
)
)
}
}
}
}
try {
for (frame in incoming) {
when(frame) {
is Frame.Text -> {
if(frame.readText().trim() == "ping") {
val text = frame.readText().trim()
if(text == "ping") {
send("pong")
} else {
val data = Json.decodeFromString<TimerRequest>(text)
if (data.type == TimerType.ACK.value) {
ackMap[data.uid]?.get(this)?.complete(true)
ackMap[data.uid]?.remove(this)
}
}
}
is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {
}
else -> {}
}
}
} catch(e: ClosedReceiveChannelException) {
logger.error("Error in WebSocket: ${e.message}")
} finally {
removeSession(uid, this)
ackMap[uid]?.remove(this)
}
}
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
dispatcher.subscribe(TimerEvent::class) {
logger.debug("TimerEvent: {} / {}", it.uid, it.type)
val user = UserService.getUser(it.uid)
CurrentTimer.setTimer(user!!, it)
CoroutineScope(Dispatchers.Default).launch {
sessions[it.uid]?.forEach { ws ->
ws.sendSerialized(TimerResponse(it.type.value, it.time ?: ""))
}
songListScope.launch {
broadcastMessage(it.uid, TimerResponse(it.type.value, it.time ?: "", it.uid))
}
}
}
@Serializable
data class TimerResponse(
val type: Int,
val time: String?
val time: String?,
val uid: String
)
@Serializable
data class TimerRequest(
val type: Int,
val uid: String
)