Compare commits

..

1 Commits

Author SHA1 Message Date
dalbodeule
8a429172cf
debug chzzk login 4 2025-02-08 04:45:01 +09:00
6 changed files with 291 additions and 425 deletions

View File

@ -40,11 +40,7 @@ object ChzzkHandler {
botUid = chzzk.loggedUser.userId botUid = chzzk.loggedUser.userId
UserService.getAllUsers().map { UserService.getAllUsers().map {
if(!it.isDisabled) if(!it.isDisabled)
try { chzzk.getChannel(it.token)?.let { token -> addUser(token, it) }
chzzk.getChannel(it.token)?.let { token -> addUser(token, it) }
} catch(e: Exception) {
logger.info("Exception: ${it.token}(${it.username}) not found. ${e.stackTraceToString()}")
}
} }
handlers.forEach { handler -> handlers.forEach { handler ->

View File

@ -5,8 +5,7 @@ enum class TimerType(var value: Int) {
TIMER(1), TIMER(1),
REMOVE(2), REMOVE(2),
STREAM_OFF(50), STREAM_OFF(50)
ACK(51)
} }
class TimerEvent( class TimerEvent(

View File

@ -25,6 +25,7 @@ import kotlinx.serialization.json.Json
import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.webserver.routes.* import space.mori.chzzk_bot.webserver.routes.*
import space.mori.chzzk_bot.webserver.utils.DiscordRatelimits import space.mori.chzzk_bot.webserver.utils.DiscordRatelimits
import wsSongListRoutes
import java.math.BigInteger import java.math.BigInteger
import java.security.SecureRandom import java.security.SecureRandom
import java.time.Duration import java.time.Duration

View File

@ -1,54 +1,115 @@
package space.mori.chzzk_bot.webserver.routes
import io.ktor.client.plugins.websocket.WebSocketException
import io.ktor.server.application.*
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.sessions.* import io.ktor.server.sessions.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.util.logging.Logger
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.* import io.ktor.websocket.Frame.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
import space.mori.chzzk_bot.common.models.SongList import space.mori.chzzk_bot.common.models.SongList
import space.mori.chzzk_bot.common.models.SongLists.uid
import space.mori.chzzk_bot.common.models.User import space.mori.chzzk_bot.common.models.User
import space.mori.chzzk_bot.common.services.SongConfigService
import space.mori.chzzk_bot.common.services.SongListService import space.mori.chzzk_bot.common.services.SongListService
import space.mori.chzzk_bot.common.services.UserService import space.mori.chzzk_bot.common.services.UserService
import space.mori.chzzk_bot.common.utils.YoutubeVideo import space.mori.chzzk_bot.common.utils.YoutubeVideo
import space.mori.chzzk_bot.common.utils.getYoutubeVideo import space.mori.chzzk_bot.common.utils.getYoutubeVideo
import space.mori.chzzk_bot.webserver.UserSession import space.mori.chzzk_bot.webserver.UserSession
import space.mori.chzzk_bot.webserver.routes.SongResponse
import space.mori.chzzk_bot.webserver.routes.toSerializable
import space.mori.chzzk_bot.webserver.utils.CurrentSong import space.mori.chzzk_bot.webserver.utils.CurrentSong
import java.io.IOException
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
fun Routing.wsSongListRoutes() { fun Routing.wsSongListRoutes() {
val sessions = ConcurrentHashMap<String, WebSocketServerSession>()
val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongListRoutes") val logger = LoggerFactory.getLogger("WSSongListRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java) val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val songListScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
// Manage all active sessions fun addSession(uid: String, session: WebSocketServerSession) {
val sessionHandlers = ConcurrentHashMap<String, SessionHandler>() if (sessions[uid] != null) {
CoroutineScope(Dispatchers.Default).launch {
// Handle application shutdown sessions[uid]?.close(
environment.monitor.subscribe(ApplicationStopped) { CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Duplicated sessions.")
sessionHandlers.values.forEach { )
songListScope.launch {
it.close(CloseReason(CloseReason.Codes.NORMAL, "Server shutting down"))
} }
} }
sessions[uid] = session
}
fun removeSession(uid: String) {
sessions.remove(uid)
}
suspend fun waitForAck(ws: WebSocketServerSession, expectedType: Int): Boolean {
val timeout = 5000L // 5 seconds timeout
val startTime = System.currentTimeMillis()
while (System.currentTimeMillis() - startTime < timeout) {
for (frame in ws.incoming) {
if (frame is Text) {
val message = frame.readText()
if(message == "ping") {
return true
}
val data = Json.decodeFromString<SongRequest>(message)
if (data.type == SongType.ACK.value) {
return true // ACK received
}
}
}
delay(100) // Check every 100 ms
}
return false // Timeout
}
suspend fun sendWithRetry(uid: String, res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
var sentSuccessfully = false
while (attempt < maxRetries && !sentSuccessfully) {
val ws = sessions[uid]
try {
if(ws == null) {
delay(delayMillis)
continue
}
// Attempt to send the message
ws.sendSerialized(res)
logger.debug("Message sent successfully to $uid on attempt $attempt")
// Wait for ACK
val ackReceived = waitForAck(ws, res.type)
if (ackReceived == true) {
sentSuccessfully = true
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
}
} catch (e: Exception) {
attempt++
logger.warn("Failed to send message to $uid on attempt $attempt. Retrying in $delayMillis ms.")
logger.warn(e.stackTraceToString())
} finally {
// Wait before retrying
delay(delayMillis)
}
}
if (!sentSuccessfully) {
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
} }
// WebSocket endpoint
webSocket("/songlist") { webSocket("/songlist") {
val session = call.sessions.get<UserSession>() val session = call.sessions.get<UserSession>()
val user: User? = session?.id?.let { UserService.getUser(it) } val user = session?.id?.let { UserService.getUser(it) }
if (user == null) { if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid SID"))
return@webSocket return@webSocket
@ -56,291 +117,175 @@ fun Routing.wsSongListRoutes() {
val uid = user.token val uid = user.token
// Ensure only one session per user addSession(uid, this)
sessionHandlers[uid]?.close(CloseReason(CloseReason.Codes.VIOLATED_POLICY, "Another session is already active."))
val handler = SessionHandler(uid, this, dispatcher, logger) if (status[uid] == SongType.STREAM_OFF) {
sessionHandlers[uid] = handler CoroutineScope(Dispatchers.Default).launch {
sendSerialized(SongResponse(
SongType.STREAM_OFF.value,
uid,
null,
null,
null,
))
}
removeSession(uid)
}
// Initialize session
handler.initialize()
// Listen for incoming frames
try { try {
for (frame in incoming) { for (frame in incoming) {
when (frame) { when (frame) {
is Frame.Text -> handler.handleTextFrame(frame.readText()) is Text -> {
is Frame.Ping -> send(Frame.Pong(frame.data)) if (frame.readText().trim() == "ping") {
else -> Unit send("pong")
} else {
val data = frame.readText().let { Json.decodeFromString<SongRequest>(it) }
// Handle song requests
handleSongRequest(data, user, dispatcher, logger)
}
}
is Ping -> send(Pong(frame.data))
else -> ""
} }
} }
} catch (e: ClosedReceiveChannelException) { } catch (e: ClosedReceiveChannelException) {
logger.info("Session closed: ${e.message}") logger.error("Error in WebSocket: ${e.message}")
} catch (e: IOException) {
logger.error("IO error: ${e.message}")
} catch (e: Exception) {
logger.error("Unexpected error: ${e.message}")
} finally { } finally {
sessionHandlers.remove(uid) removeSession(uid)
handler.close(CloseReason(CloseReason.Codes.NORMAL, "Session ended"))
} }
} }
// Subscribe to SongEvents dispatcher.subscribe(SongEvent::class) {
dispatcher.subscribe(SongEvent::class) { event -> logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
val handler = sessionHandlers[event.uid] CoroutineScope(Dispatchers.Default).launch {
songListScope.launch { val user = UserService.getUser(it.uid)
handler?.sendSongResponse(event) if (user != null) {
sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
it.reqUid,
it.current?.toSerializable(),
it.next?.toSerializable(),
it.delUrl
)
)
}
} }
} }
// Subscribe to TimerEvents dispatcher.subscribe(TimerEvent::class) {
dispatcher.subscribe(TimerEvent::class) { event -> if (it.type == TimerType.STREAM_OFF) {
if (event.type == TimerType.STREAM_OFF) { CoroutineScope(Dispatchers.Default).launch {
val handler = sessionHandlers[event.uid] val user = UserService.getUser(it.uid)
songListScope.launch { if (user != null) {
handler?.sendTimerOff() sendWithRetry(
user.token, SongResponse(
it.type.value,
it.uid,
null,
null,
null,
)
)
}
} }
} }
} }
} }
class SessionHandler( suspend fun handleSongRequest(
private val uid: String, data: SongRequest,
private val session: WebSocketServerSession, user: User,
private val dispatcher: CoroutinesEventBus, dispatcher: CoroutinesEventBus,
private val logger: Logger logger: Logger
) { ) {
private val ackMap = ConcurrentHashMap<String, CompletableDeferred<Boolean>>() if (data.maxQueue != null && data.maxQueue > 0) SongConfigService.updateQueueLimit(user, data.maxQueue)
private val sessionMutex = Mutex() if (data.maxUserLimit != null && data.maxUserLimit > 0) SongConfigService.updatePersonalLimit(user, data.maxUserLimit)
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) if (data.isStreamerOnly != null) SongConfigService.updateStreamerOnly(user, data.isStreamerOnly)
if (data.isDisabled != null) SongConfigService.updateDisabled(user, data.isDisabled)
suspend fun initialize() { when (data.type) {
// Send initial status if needed, SongType.ADD.value -> {
// For example, send STREAM_OFF if applicable data.url?.let { url ->
// This can be extended based on your requirements try {
} val youtubeVideo = getYoutubeVideo(url)
if (youtubeVideo != null) {
suspend fun handleTextFrame(text: String) { CoroutineScope(Dispatchers.Default).launch {
if (text.trim() == "ping") { SongListService.saveSong(
session.send("pong") user,
return user.token,
} url,
youtubeVideo.name,
val data = try { youtubeVideo.author,
Json.decodeFromString<SongRequest>(text) youtubeVideo.length,
} catch (e: Exception) { user.username
logger.warn("Failed to decode SongRequest: ${e.message}") )
return dispatcher.post(
} SongEvent(
user.token,
when (data.type) { SongType.ADD,
SongType.ACK.value -> handleAck(data.uid) user.token,
else -> handleSongRequest(data) CurrentSong.getSong(user),
} youtubeVideo
} )
)
private fun handleAck(requestUid: String) { }
ackMap[requestUid]?.complete(true) }
ackMap.remove(requestUid) } catch (e: Exception) {
} logger.debug("SongType.ADD Error: $uid $e")
private fun handleSongRequest(data: SongRequest) {
scope.launch {
SongRequestProcessor.process(data, uid, dispatcher, this@SessionHandler, logger)
}
}
suspend fun sendSongResponse(event: SongEvent) {
val response = SongResponse(
type = event.type.value,
uid = event.uid,
reqUid = event.reqUid,
current = event.current?.toSerializable(),
next = event.next?.toSerializable(),
delUrl = event.delUrl
)
sendWithRetry(response)
}
suspend fun sendTimerOff() {
val response = SongResponse(
type = TimerType.STREAM_OFF.value,
uid = uid,
reqUid = null,
current = null,
next = null,
delUrl = null
)
sendWithRetry(response)
}
private suspend fun sendWithRetry(res: SongResponse, maxRetries: Int = 5, delayMillis: Long = 3000L) {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(res)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap[res.uid] = ackDeferred
val ackReceived = withTimeoutOrNull(5000L) { ackDeferred.await() } ?: false
if (ackReceived) {
logger.debug("ACK received for message to $uid on attempt $attempt.")
return
} else {
logger.warn("ACK not received for message to $uid on attempt $attempt.")
} }
} catch (e: IOException) { }
logger.warn("Failed to send message to $uid on attempt $attempt: ${e.message}") }
if (e is WebSocketException) { SongType.REMOVE.value -> {
close(CloseReason(CloseReason.Codes.PROTOCOL_ERROR, "WebSocket error")) data.url?.let { url ->
return val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) {
SongListService.deleteSong(user, exactSong.uid, exactSong.name)
} }
} catch (e: CancellationException) { dispatcher.post(
throw e SongEvent(
} catch (e: Exception) { user.token,
logger.warn("Unexpected error while sending message to $uid on attempt $attempt: ${e.message}") SongType.REMOVE,
} null,
null,
attempt++ null,
delay(delayMillis) url
} )
)
logger.error("Failed to send message to $uid after $maxRetries attempts.")
}
suspend fun close(reason: CloseReason) {
try {
session.close(reason)
} catch (e: Exception) {
logger.warn("Error closing session: ${e.message}")
}
}
}
object SongRequestProcessor {
private val songMutex = Mutex()
suspend fun process(
data: SongRequest,
uid: String,
dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger
) {
val user = UserService.getUser(uid) ?: return
when (data.type) {
SongType.ADD.value -> handleAdd(data, user, dispatcher, handler, logger)
SongType.REMOVE.value -> handleRemove(data, user, dispatcher, logger)
SongType.NEXT.value -> handleNext(user, dispatcher, logger)
else -> {
// Handle other types if necessary
} }
} }
} SongType.NEXT.value -> {
private suspend fun handleAdd(
data: SongRequest,
user: User,
dispatcher: CoroutinesEventBus,
handler: SessionHandler,
logger: Logger
) {
val url = data.url ?: return
val youtubeVideo = getYoutubeVideo(url) ?: run {
logger.warn("Failed to fetch YouTube video for URL: $url")
return
}
songMutex.withLock {
SongListService.saveSong(
user,
user.token,
url,
youtubeVideo.name,
youtubeVideo.author,
youtubeVideo.length,
user.username
)
}
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.ADD,
reqUid = user.token,
current = CurrentSong.getSong(user),
next = youtubeVideo
)
)
}
private suspend fun handleRemove(
data: SongRequest,
user: User,
dispatcher: CoroutinesEventBus,
logger: Logger
) {
val url = data.url ?: return
songMutex.withLock {
val songs = SongListService.getSong(user)
val exactSong = songs.firstOrNull { it.url == url }
if (exactSong != null) {
SongListService.deleteSong(user, exactSong.uid, exactSong.name)
}
}
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.REMOVE,
delUrl = url,
reqUid = null,
current = null,
next = null,
)
)
}
private suspend fun handleNext(
user: User,
dispatcher: CoroutinesEventBus,
logger: Logger
) {
var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null
songMutex.withLock {
val songList = SongListService.getSong(user) val songList = SongListService.getSong(user)
var song: SongList? = null
var youtubeVideo: YoutubeVideo? = null
if (songList.isNotEmpty()) { if (songList.isNotEmpty()) {
song = songList[0] song = songList[0]
SongListService.deleteSong(user, song.uid, song.name) SongListService.deleteSong(user, song.uid, song.name)
} }
}
song?.let { song?.let {
youtubeVideo = YoutubeVideo( youtubeVideo = YoutubeVideo(
it.url, song.url,
it.name, song.name,
it.author, song.author,
it.time song.time
)
}
dispatcher.post(
SongEvent(
user.token,
SongType.NEXT,
song?.uid,
youtubeVideo
)
) )
CurrentSong.setSong(user, youtubeVideo)
} }
dispatcher.post(
SongEvent(
uid = user.token,
type = SongType.NEXT,
current = null,
next = youtubeVideo,
reqUid = null,
delUrl = null
)
)
CurrentSong.setSong(user, youtubeVideo)
} }
} }
@ -348,10 +293,10 @@ object SongRequestProcessor {
data class SongRequest( data class SongRequest(
val type: Int, val type: Int,
val uid: String, val uid: String,
val url: String? = null, val url: String?,
val maxQueue: Int? = null, val maxQueue: Int?,
val maxUserLimit: Int? = null, val maxUserLimit: Int?,
val isStreamerOnly: Boolean? = null, val isStreamerOnly: Boolean?,
val remove: Int? = null, val remove: Int?,
val isDisabled: Boolean? = null val isDisabled: Boolean?,
) )

View File

@ -1,20 +1,14 @@
package space.mori.chzzk_bot.webserver.routes package space.mori.chzzk_bot.webserver.routes
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
@ -23,20 +17,15 @@ import space.mori.chzzk_bot.common.utils.YoutubeVideo
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.ConcurrentLinkedQueue
val songScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsSongRoutes() { fun Routing.wsSongRoutes() {
environment.monitor.subscribe(ApplicationStopped) {
songScope.cancel()
}
val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>() val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>()
val status = ConcurrentHashMap<String, SongType>() val status = ConcurrentHashMap<String, SongType>()
val logger = LoggerFactory.getLogger("WSSongRoutes") val logger = LoggerFactory.getLogger("WSSongRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val ackMap = ConcurrentHashMap<String, ConcurrentHashMap<WebSocketServerSession, CompletableDeferred<Boolean>>>()
fun addSession(uid: String, session: WebSocketServerSession) { fun addSession(uid: String, session: WebSocketServerSession) {
sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session) sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session)
} }
fun removeSession(uid: String, session: WebSocketServerSession) { fun removeSession(uid: String, session: WebSocketServerSession) {
sessions[uid]?.remove(session) sessions[uid]?.remove(session)
if(sessions[uid]?.isEmpty() == true) { if(sessions[uid]?.isEmpty() == true) {
@ -53,35 +42,27 @@ fun Routing.wsSongRoutes() {
var attempt = 0 var attempt = 0
while (attempt < maxRetries) { while (attempt < maxRetries) {
try { try {
session.sendSerialized(message) session.sendSerialized(message) // 메시지 전송 시도
val ackDeferred = CompletableDeferred<Boolean>() return true // 성공하면 true 반환
ackMap.computeIfAbsent(message.uid) { ConcurrentHashMap() }[session] = ackDeferred
val ackReceived = withTimeoutOrNull(delayMillis) { ackDeferred.await() } ?: false
if (ackReceived) {
ackMap[message.uid]?.remove(session)
return true
} else {
attempt++
logger.warn("ACK not received for message to ${message.uid} on attempt $attempt.")
}
} catch (e: Exception) { } catch (e: Exception) {
attempt++ attempt++
logger.info("Failed to send message on attempt $attempt. Retrying in $delayMillis ms.") logger.info("Failed to send message on attempt $attempt. Retrying in $delayMillis ms.")
e.printStackTrace() e.printStackTrace()
delay(delayMillis) delay(delayMillis) // 재시도 전 대기
} }
} }
return false return false // 재시도 실패 시 false 반환
} }
fun broadcastMessage(userId: String, message: SongResponse) { fun broadcastMessage(userId: String, message: SongResponse) {
val userSessions = sessions[userId] val userSessions = sessions[userId]
userSessions?.forEach { session -> userSessions?.forEach { session ->
songScope.launch { CoroutineScope(Dispatchers.Default).launch {
val success = sendWithRetry(session, message) val success = sendWithRetry(session, message)
if (!success) { if (!success) {
logger.info("Removing session for user $userId due to repeated failures.") println("Removing session for user $userId due to repeated failures.")
removeSession(userId, session) userSessions.remove(session) // 실패 시 세션 제거
} }
} }
} }
@ -90,13 +71,19 @@ fun Routing.wsSongRoutes() {
webSocket("/song/{uid}") { webSocket("/song/{uid}") {
val uid = call.parameters["uid"] val uid = call.parameters["uid"]
val user = uid?.let { UserService.getUser(it) } val user = uid?.let { UserService.getUser(it) }
if (uid == null || user == null) { if (uid == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket return@webSocket
} }
if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
addSession(uid, this) addSession(uid, this)
if(status[uid] == SongType.STREAM_OFF) { if(status[uid] == SongType.STREAM_OFF) {
songScope.launch { CoroutineScope(Dispatchers.Default).launch {
sendSerialized(SongResponse( sendSerialized(SongResponse(
SongType.STREAM_OFF.value, SongType.STREAM_OFF.value,
uid, uid,
@ -106,36 +93,33 @@ fun Routing.wsSongRoutes() {
)) ))
} }
} }
try { try {
for (frame in incoming) { for (frame in incoming) {
when(frame) { when(frame) {
is Frame.Text -> { is Frame.Text -> {
val text = frame.readText().trim() if(frame.readText().trim() == "ping") {
if(text == "ping") {
send("pong") send("pong")
} else {
val data = Json.decodeFromString<SongRequest>(text)
if (data.type == SongType.ACK.value) {
ackMap[data.uid]?.get(this)?.complete(true)
ackMap[data.uid]?.remove(this)
}
} }
} }
is Frame.Ping -> send(Frame.Pong(frame.data)) is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {} else -> {
}
} }
} }
} catch(e: ClosedReceiveChannelException) { } catch(e: ClosedReceiveChannelException) {
logger.error("Error in WebSocket: ${e.message}") logger.error("Error in WebSocket: ${e.message}")
} finally { } finally {
removeSession(uid, this) removeSession(uid, this)
ackMap[uid]?.remove(this)
} }
} }
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
dispatcher.subscribe(SongEvent::class) { dispatcher.subscribe(SongEvent::class) {
logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name) logger.debug("SongEvent: {} / {} {}", it.uid, it.type, it.current?.name)
songScope.launch { CoroutineScope(Dispatchers.Default).launch {
broadcastMessage(it.uid, SongResponse( broadcastMessage(it.uid, SongResponse(
it.type.value, it.type.value,
it.uid, it.uid,
@ -148,7 +132,7 @@ fun Routing.wsSongRoutes() {
} }
dispatcher.subscribe(TimerEvent::class) { dispatcher.subscribe(TimerEvent::class) {
if(it.type == TimerType.STREAM_OFF) { if(it.type == TimerType.STREAM_OFF) {
songScope.launch { CoroutineScope(Dispatchers.Default).launch {
broadcastMessage(it.uid, SongResponse( broadcastMessage(it.uid, SongResponse(
it.type.value, it.type.value,
it.uid, it.uid,
@ -160,6 +144,7 @@ fun Routing.wsSongRoutes() {
} }
} }
} }
@Serializable @Serializable
data class SerializableYoutubeVideo( data class SerializableYoutubeVideo(
val url: String, val url: String,
@ -167,7 +152,9 @@ data class SerializableYoutubeVideo(
val author: String, val author: String,
val length: Int val length: Int
) )
fun YoutubeVideo.toSerializable() = SerializableYoutubeVideo(url, name, author, length) fun YoutubeVideo.toSerializable() = SerializableYoutubeVideo(url, name, author, length)
@Serializable @Serializable
data class SongResponse( data class SongResponse(
val type: Int, val type: Int,

View File

@ -1,20 +1,13 @@
package space.mori.chzzk_bot.webserver.routes package space.mori.chzzk_bot.webserver.routes
import io.ktor.server.application.ApplicationStopped
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import io.ktor.websocket.* import io.ktor.websocket.*
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import org.koin.java.KoinJavaComponent.inject import org.koin.java.KoinJavaComponent.inject
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import space.mori.chzzk_bot.common.events.* import space.mori.chzzk_bot.common.events.*
@ -24,19 +17,14 @@ import space.mori.chzzk_bot.webserver.utils.CurrentTimer
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.ConcurrentLinkedQueue
val timerScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)
fun Routing.wsTimerRoutes() { fun Routing.wsTimerRoutes() {
environment.monitor.subscribe(ApplicationStopped) {
timerScope.cancel()
}
val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>() val sessions = ConcurrentHashMap<String, ConcurrentLinkedQueue<WebSocketServerSession>>()
val logger = LoggerFactory.getLogger("WSTimerRoutes") val logger = LoggerFactory.getLogger("WSTimerRoutes")
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
val ackMap = ConcurrentHashMap<String, ConcurrentHashMap<WebSocketServerSession, CompletableDeferred<Boolean>>>()
fun addSession(uid: String, session: WebSocketServerSession) { fun addSession(uid: String, session: WebSocketServerSession) {
sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session) sessions.computeIfAbsent(uid) { ConcurrentLinkedQueue() }.add(session)
} }
fun removeSession(uid: String, session: WebSocketServerSession) { fun removeSession(uid: String, session: WebSocketServerSession) {
sessions[uid]?.remove(session) sessions[uid]?.remove(session)
if(sessions[uid]?.isEmpty() == true) { if(sessions[uid]?.isEmpty() == true) {
@ -44,132 +32,82 @@ fun Routing.wsTimerRoutes() {
} }
} }
suspend fun sendWithRetry(
session: WebSocketServerSession,
message: TimerResponse,
maxRetries: Int = 3,
delayMillis: Long = 2000L
): Boolean {
var attempt = 0
while (attempt < maxRetries) {
try {
session.sendSerialized(message)
val ackDeferred = CompletableDeferred<Boolean>()
ackMap.computeIfAbsent(message.uid) { ConcurrentHashMap() }[session] = ackDeferred
val ackReceived = withTimeoutOrNull(delayMillis) { ackDeferred.await() } ?: false
if (ackReceived) {
ackMap[message.uid]?.remove(session)
return true
} else {
attempt++
logger.warn("ACK not received for message to ${message.uid} on attempt $attempt.")
}
} catch (e: Exception) {
attempt++
logger.info("Failed to send message on attempt $attempt. Retrying in $delayMillis ms.")
e.printStackTrace()
delay(delayMillis)
}
}
return false
}
fun broadcastMessage(uid: String, message: TimerResponse) {
val userSessions = sessions[uid]
userSessions?.forEach { session ->
timerScope.launch {
val success = sendWithRetry(session, message.copy(uid = uid))
if (!success) {
logger.info("Removing session for user $uid due to repeated failures.")
removeSession(uid, session)
}
}
}
}
webSocket("/timer/{uid}") { webSocket("/timer/{uid}") {
val uid = call.parameters["uid"] val uid = call.parameters["uid"]
val user = uid?.let { UserService.getUser(it) } val user = uid?.let { UserService.getUser(it) }
if (uid == null || user == null) { if (uid == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID")) close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket return@webSocket
} }
if (user == null) {
close(CloseReason(CloseReason.Codes.CANNOT_ACCEPT, "Invalid UID"))
return@webSocket
}
addSession(uid, this) addSession(uid, this)
val timer = CurrentTimer.getTimer(user) val timer = CurrentTimer.getTimer(user)
if (timer?.type == TimerType.STREAM_OFF) { if(timer?.type == TimerType.STREAM_OFF) {
timerScope.launch { CoroutineScope(Dispatchers.Default).launch {
sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null, uid)) sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null))
} }
} else { } else {
timerScope.launch { CoroutineScope(Dispatchers.Default).launch {
if(timer?.type == TimerType.STREAM_OFF) { if (timer == null) {
sendSerialized(TimerResponse(TimerType.STREAM_OFF.value, null, uid)) sendSerialized(
TimerResponse(
TimerConfigService.getConfig(user)?.option ?: TimerType.REMOVE.value,
null
)
)
} else { } else {
if (timer == null) { sendSerialized(
sendSerialized( TimerResponse(
TimerResponse( timer.type.value,
TimerConfigService.getConfig(user)?.option ?: TimerType.REMOVE.value, timer.time
null,
uid
)
) )
} else { )
sendSerialized(
TimerResponse(
timer.type.value,
timer.time,
uid
)
)
}
} }
} }
} }
try { try {
for (frame in incoming) { for (frame in incoming) {
when(frame) { when(frame) {
is Frame.Text -> { is Frame.Text -> {
val text = frame.readText().trim() if(frame.readText().trim() == "ping") {
if(text == "ping") {
send("pong") send("pong")
} else {
val data = Json.decodeFromString<TimerRequest>(text)
if (data.type == TimerType.ACK.value) {
ackMap[data.uid]?.get(this)?.complete(true)
ackMap[data.uid]?.remove(this)
}
} }
} }
is Frame.Ping -> send(Frame.Pong(frame.data)) is Frame.Ping -> send(Frame.Pong(frame.data))
else -> {} else -> {
}
} }
} }
} catch(e: ClosedReceiveChannelException) { } catch(e: ClosedReceiveChannelException) {
logger.error("Error in WebSocket: ${e.message}") logger.error("Error in WebSocket: ${e.message}")
} finally { } finally {
removeSession(uid, this) removeSession(uid, this)
ackMap[uid]?.remove(this)
} }
} }
val dispatcher: CoroutinesEventBus by inject(CoroutinesEventBus::class.java)
dispatcher.subscribe(TimerEvent::class) { dispatcher.subscribe(TimerEvent::class) {
logger.debug("TimerEvent: {} / {}", it.uid, it.type) logger.debug("TimerEvent: {} / {}", it.uid, it.type)
val user = UserService.getUser(it.uid) val user = UserService.getUser(it.uid)
CurrentTimer.setTimer(user!!, it) CurrentTimer.setTimer(user!!, it)
timerScope.launch { CoroutineScope(Dispatchers.Default).launch {
broadcastMessage(it.uid, TimerResponse(it.type.value, it.time ?: "", it.uid)) sessions[it.uid]?.forEach { ws ->
ws.sendSerialized(TimerResponse(it.type.value, it.time ?: ""))
}
} }
} }
} }
@Serializable @Serializable
data class TimerResponse( data class TimerResponse(
val type: Int, val type: Int,
val time: String?, val time: String?
val uid: String
)
@Serializable
data class TimerRequest(
val type: Int,
val uid: String
) )