Add SongListWebSocketManager and refactor WebSocket routes

Introduced SongListWebSocketManager for managing WebSocket sessions, including ping-pong handling and retry mechanisms. Refactored WSSongListRoutes to delegate session management and simplify logic by leveraging the new manager class.
This commit is contained in:
dalbodeule 2025-04-24 15:58:56 +09:00
parent 17d8065a34
commit 02cede87f8
No known key found for this signature in database
GPG Key ID: EFA860D069C9FA65
2 changed files with 299 additions and 150 deletions

View File

@ -3,16 +3,14 @@ import io.ktor.server.sessions.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger import io.ktor.util.logging.Logger
import io.ktor.websocket.* import io.ktor.websocket.*
import io.ktor.server.application.*
import io.ktor.websocket.Frame.* import io.ktor.websocket.Frame.*
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.SongList
import space.mori.chzzk_bot.common.models.SongLists.uid import space.mori.chzzk_bot.common.models.SongLists.uid
@ -23,177 +21,56 @@ import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.routes.SongResponse
import space.mori.chzzk_bot.webserver.routes.toSerializable
import space.mori.chzzk_bot.webserver.utils.CurrentSong import space.mori.chzzk_bot.webserver.utils.CurrentSong
import java.util.concurrent.ConcurrentHashMap import space.mori.chzzk_bot.webserver.utils.SongListWebSocketManager
fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes")
fun Route.wsSongListRoutes(songListWebSocketManager: SongListWebSocketManager) {
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
fun addSession(uid: String, session: WebSocketServerSession) {
if (sessions[uid] != null) {
CoroutineScope(Dispatchers.Default).launch {
sessions[uid]?.close(
CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Duplicated sessions.")
)
}
}
sessions[uid] = session
}
fun removeSession(uid: String) {
sessions.remove(uid)
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedType: Int): Boolean {
val timeout = 5000L // 5 seconds timeout
val startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < timeout) {
for (frame in ws.incoming) {
if (frame is Text) {
val message = frame.readText()
if(message == "ping") {
return true
}
val data = Json.decodeFromString<SongRequest>(message)
if (data.type == SongType.ACK.value) {
return true // ACK received
}
}
}
delay(100) // Check every 100 ms
}
return false // Timeout
}
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws = sessions[uid]
try {
if(ws == null) {
delay(delayMillis)
continue
}
// Attempt to send the message
ws.sendSerialized(res)
logger.debug("Message sent successfully to $uid on attempt $attempt")
// Wait for ACK
val ackReceived = waitForAck(ws, res.type)
if (ackReceived == true) {
sentSuccessfully = true
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
}
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt. Retrying in $delayMillis ms.")
logger.warn(e.stackTraceToString())
} finally {
// Wait before retrying
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
}
webSocket("/songlist") { webSocket("/songlist") {
val session = call.sessions.get<UserSession>() val session = call.sessions.get<UserSession>()
val user = session?.id?.let { UserService.getUser(it) } val uid = session?.id
if (user == null) { if (uid == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid session"))
return@webSocket return@webSocket
} }
val uid = user.token
addSession(uid, this)
if (status[uid] == SongType.STREAM_OFF) {
CoroutineScope(Dispatchers.Default).launch {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
null,
null,
null,
))
}
removeSession(uid)
}
try { try {
songListWebSocketManager.addSession(uid, this)
for (frame in incoming) { for (frame in incoming) {
when (frame) { when (frame) {
is Text -> { is Text -> {
if (frame.readText().trim() == "ping") { val text = frame.readText().trim()
send("pong") if (text == SongListWebSocketManager.PING_MESSAGE) {
send(SongListWebSocketManager.PONG_MESSAGE)
songListWebSocketManager.handlePong(uid)
} else { } else {
val data = frame.readText().let { Json.decodeFromString<SongRequest>(it) }
// Handle song requests // Handle song requests
handleSongRequest(data, user, dispatcher, logger) text.let { Json.decodeFromString<SongRequest>(it) }.let { data ->
val user = session.id.let { UserService.getUser(it) }
if(user == null) {
songListWebSocketManager.removeSession(uid)
return@webSocket
}
handleSongRequest(data, user, dispatcher, songListWebSocketManager.logger)
}.runCatching { songListWebSocketManager.logger.error("Failed to parse WebSocket message as SongRequest.") }
} }
} }
is Ping -> send(Pong(frame.data)) is Ping -> send(Pong(frame.data))
else -> "" else -> songListWebSocketManager.logger.warn("Unsupported frame type received.")
} }
} }
} catch (e: ClosedReceiveChannelException) { } catch (e: Exception) {
logger.error("Error in WebSocket: ${e.message}") songListWebSocketManager.logger.error("WebSocket error: ${e.message}")
} finally { } finally {
removeSession(uid) songListWebSocketManager.removeSession(uid)
}
}
dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
CoroutineScope(Dispatchers.Default).launch {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
it.reqUid,
it.current?.toSerializable(),
it.next?.toSerializable(),
it.delUrl
)
)
} }
} }
} }
dispatcher.subscribe(TimerEvent::class) {
if (it.type == TimerType.STREAM_OFF) {
CoroutineScope(Dispatchers.Default).launch {
val user = UserService.getUser(it.uid)
if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
null,
null,
null,
)
)
}
}
}
}
}
suspend fun handleSongRequest( suspend fun handleSongRequest(
data: SongRequest, data: SongRequest,
@ -289,6 +166,7 @@ suspend fun handleSongRequest(
} }
} }
@Serializable @Serializable
data class SongRequest( data class SongRequest(
val type: Int, val type: Int,

View File

@ -0,0 +1,271 @@
package space.mori.chzzk_bot.webserver.utils
import io.ktor.server.websocket.WebSocketServerSession
import io.ktor.util.logging.Logger
import io.ktor.websocket.CloseReason
import io.ktor.websocket.Frame
import io.ktor.websocket.close
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.serialization.json.Json
import space.mori.chzzk_bot.common.events.SongType
import space.mori.chzzk_bot.webserver.routes.SongResponse
import java.util.concurrent.ConcurrentHashMap
import kotlin.collections.set
class SongListWebSocketManager(internal val logger: Logger) {
companion object {
private const val ACK_TIMEOUT_MS = 5000L
private const val ACK_CHECK_INTERVAL_MS = 100L
private const val MAX_RETRY_ATTEMPTS = 5
private const val RETRY_DELAY_MS = 3000L
internal const val PING_MESSAGE = "ping"
internal const val PONG_MESSAGE = "pong"
private const val KEEP_ALIVE_INTERVAL_MS = 30000L // 30초 간격으로 핑-퐁 메시지 전송
}
private val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
private val status = ConcurrentHashMap<String, SongType>()
private val lastActivity = ConcurrentHashMap<String, Long>()
private val pendingAcks = ConcurrentHashMap<String, Boolean>()
// 핑-퐁 상태 관리를 위한 맵
private val pingStatus = ConcurrentHashMap<String, Boolean>()
private val pingJobs = ConcurrentHashMap<String, kotlinx.coroutines.Job>()
private val mutex = kotlinx.coroutines.sync.Mutex()
init {
// 정기적으로 비활성 연결을 체크하는 백그라운드 작업
CoroutineScope(Dispatchers.Default).launch {
while (true) {
try {
checkInactiveSessions()
delay(60000) // 1분마다 체크
} catch (e: Exception) {
logger.error("비활성 세션 체크 중 오류 발생: ${e.message}")
}
}
}
}
suspend fun addSession(uid: String, session: WebSocketServerSession) {
mutex.lock()
try {
sessions[uid] = session
lastActivity[uid] = System.currentTimeMillis()
// 핑-퐁 작업 시작
startPingPongJob(uid, session)
logger.info("웹소켓 세션 추가됨: $uid")
} finally {
mutex.unlock()
}
}
suspend fun removeSession(uid: String) {
mutex.lock()
try {
sessions.remove(uid)
status.remove(uid)
lastActivity.remove(uid)
pendingAcks.remove(uid)
// 핑-퐁 작업 중지
stopPingPongJob(uid)
logger.info("웹소켓 세션 제거됨: $uid")
} finally {
mutex.unlock()
}
}
private fun startPingPongJob(uid: String, session: WebSocketServerSession) {
pingStatus[uid] = true
pingJobs[uid] = CoroutineScope(Dispatchers.IO).launch {
while (pingStatus[uid] == true && sessions[uid] != null) {
try {
session.send(Frame.Text(PING_MESSAGE))
logger.debug("핑 메시지 전송: $uid")
// 응답 대기 (타임아웃 처리는 메시지 수신부에서)
mutex.lock()
try {
lastActivity[uid] = System.currentTimeMillis()
} finally {
mutex.unlock()
}
delay(KEEP_ALIVE_INTERVAL_MS)
} catch (e: Exception) {
logger.error("핑-퐁 작업 중 오류 발생: ${e.message}")
// 연결에 문제가 있으면 세션을 제거
if (e is ClosedReceiveChannelException || e is java.io.IOException) {
logger.warn("연결 문제로 세션 제거: $uid")
removeSession(uid)
break
}
delay(5000) // 오류 발생 시 잠시 대기 후 재시도
}
}
}
}
private fun stopPingPongJob(uid: String) {
pingStatus[uid] = false
pingJobs[uid]?.cancel()
pingJobs.remove(uid)
}
// 비활성 세션 점검 및 제거
private suspend fun checkInactiveSessions() {
val currentTime = System.currentTimeMillis()
val inactiveTimeout = 3 * KEEP_ALIVE_INTERVAL_MS // 3번의 핑-퐁 주기 이상 응답이 없으면 비활성으로 간주
val inactiveSessions = lastActivity.entries
.filter { currentTime - it.value > inactiveTimeout }
.map { it.key }
mutex.lock()
try {
inactiveSessions.forEach { uid ->
logger.warn("비활성 세션 감지됨. 제거 중: $uid (마지막 활동: ${(currentTime - lastActivity[uid]!!) / 1000}초 전)")
try {
sessions[uid]?.close(CloseReason(CloseReason.Codes.GOING_AWAY, "비활성 연결 감지"))
} catch (e: Exception) {
logger.error("비활성 세션 닫기 실패: ${e.message}")
} finally {
removeSession(uid)
}
}
} finally {
mutex.unlock()
}
if (inactiveSessions.isNotEmpty()) {
logger.info("${inactiveSessions.size}개의 비활성 세션이 제거됨")
}
}
// 활동 기록 업데이트
suspend fun updateActivity(uid: String) {
mutex.lock()
try {
lastActivity[uid] = System.currentTimeMillis()
} finally {
mutex.unlock()
}
}
// 퐁 메시지 처리
suspend fun handlePong(uid: String) {
updateActivity(uid)
logger.debug("퐁 메시지 수신: $uid")
}
suspend fun waitForAck(ws: WebSocketServerSession): Boolean {
val sessionId = ws.hashCode().toString()
mutex.lock()
try {
pendingAcks[sessionId] = false
} finally {
mutex.unlock()
}
val startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < ACK_TIMEOUT_MS) {
mutex.lock()
try {
if (pendingAcks[sessionId] == true) {
pendingAcks.remove(sessionId)
return true
}
} finally {
mutex.unlock()
}
delay(ACK_CHECK_INTERVAL_MS)
}
mutex.lock()
try {
pendingAcks.remove(sessionId)
} finally {
mutex.unlock()
}
return false
}
suspend fun acknowledgeMessage(sessionId: String) {
mutex.lock()
try {
pendingAcks[sessionId] = true
} finally {
mutex.unlock()
}
}
suspend fun sendWithRetry(uid: String, response: SongResponse) {
val session = sessions[uid] ?: run {
logger.warn("세션을 찾을 수 없음: $uid")
return
}
var attempts = 0
val jsonResponse = Json.encodeToString(SongResponse.serializer(), response)
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
session.send(Frame.Text(jsonResponse))
updateActivity(uid)
if (waitForAck(session)) {
logger.debug("메시지 전송 성공 (시도 ${attempts + 1}): $uid")
return
} else {
logger.warn("확인 응답 없음 (시도 ${attempts + 1}): $uid")
}
} catch (e: Exception) {
logger.error("메시지 전송 실패 (시도 ${attempts + 1}): ${e.message}")
if (e is ClosedReceiveChannelException || e is java.io.IOException) {
logger.warn("연결 끊김으로 인한 세션 제거: $uid")
removeSession(uid)
return
}
}
attempts++
if (attempts < MAX_RETRY_ATTEMPTS) {
delay(RETRY_DELAY_MS)
}
}
logger.error("최대 재시도 횟수 초과 후 메시지 전송 실패: $uid")
// 여러 번 실패 후 연결이 불안정한 것으로 판단하여 세션 제거
removeSession(uid)
}
suspend fun getUserStatus(uid: String): SongType? {
mutex.lock()
return try {
status[uid]
} finally {
mutex.unlock()
}
}
suspend fun updateUserStatus(uid: String, songType: SongType) {
mutex.lock()
try {
status[uid] = songType
} finally {
mutex.unlock()
}
}
}