diff --git a/build.sbt b/build.sbt index 07d74db..011c9bb 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ import org.scalajs.linker.interface.ModuleSplitStyle import scala.sys.process.* -lazy val projectVersion = "2.4.2" +lazy val projectVersion = "2.4.3" lazy val organizationName = "ru.trett" lazy val scala3Version = "3.7.4" lazy val circeVersion = "0.14.15" @@ -120,6 +120,7 @@ lazy val server = project ).map(_ % doobieVersion), libraryDependencies += "org.jsoup" % "jsoup" % "1.21.2", libraryDependencies += "com.github.blemale" %% "scaffeine" % "5.3.0", + libraryDependencies += "io.circe" %% "circe-fs2" % "0.14.1", libraryDependencies += "org.flywaydb" % "flyway-database-postgresql" % "11.17.2" % "runtime", libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.19" % Test, libraryDependencies += "org.scalamock" %% "scalamock" % "7.5.2" % Test, diff --git a/client/src/main/scala/client/Models.scala b/client/src/main/scala/client/Models.scala index 8232cf2..f980498 100644 --- a/client/src/main/scala/client/Models.scala +++ b/client/src/main/scala/client/Models.scala @@ -24,6 +24,24 @@ object Decoders: } given Decoder[SummaryResponse] = deriveDecoder + import SummaryEvent.* + given Decoder[Content] = deriveDecoder + given Decoder[Metadata] = deriveDecoder + given Decoder[FunFact] = deriveDecoder + given Decoder[Error] = deriveDecoder + + given Decoder[SummaryEvent] = Decoder.instance { cursor => + cursor.downField("type").as[String].flatMap { + case "content" => cursor.as[Content] + case "metadata" => cursor.as[Metadata] + case "funFact" => cursor.as[FunFact] + case "error" => cursor.as[Error] + case "done" => Right(Done) + case other => + Left(io.circe.DecodingFailure(s"Unknown SummaryEvent type: $other", cursor.history)) + } + } + final class Model: val feedVar: Var[FeedItemList] = Var(List()) val channelVar: Var[ChannelList] = Var(List()) diff --git a/client/src/main/scala/client/NetworkUtils.scala b/client/src/main/scala/client/NetworkUtils.scala index 1ea0474..66935d3 100644 --- a/client/src/main/scala/client/NetworkUtils.scala +++ b/client/src/main/scala/client/NetworkUtils.scala @@ -15,6 +15,7 @@ import scala.util.Failure import scala.util.Success import scala.util.Try import ru.trett.rss.models.UserSettings +import ru.trett.rss.models.SummaryEvent object NetworkUtils { @@ -68,4 +69,19 @@ object NetworkUtils { def logout(): EventStream[Unit] = FetchStream.post("/api/logout", _.body("")).mapTo(()) + + def streamSummary(url: String): (EventStream[Try[SummaryEvent]], () => Unit) = + val bus = new EventBus[Try[SummaryEvent]] + val source = new dom.EventSource(url) + + source.onmessage = msg => + decode[SummaryEvent](msg.data.toString) match + case Right(event) => bus.emit(Success(event)) + case Left(err) => bus.emit(Failure(err)) + + source.onerror = _ => + bus.emit(Failure(new RuntimeException("Stream error"))) + source.close() + + (bus.events, () => source.close()) } diff --git a/client/src/main/scala/client/SummaryPage.scala b/client/src/main/scala/client/SummaryPage.scala index 6f16c54..674a805 100644 --- a/client/src/main/scala/client/SummaryPage.scala +++ b/client/src/main/scala/client/SummaryPage.scala @@ -4,14 +4,12 @@ import be.doeraene.webcomponents.ui5.* import be.doeraene.webcomponents.ui5.configkeys.* import client.NetworkUtils.* import com.raquo.laminar.api.L.* -import ru.trett.rss.models.{SummaryResponse, SummarySuccess, SummaryError} +import ru.trett.rss.models.SummaryEvent -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Success} object SummaryPage: - import Decoders.given - private val model = AppState.model private case class PageState( @@ -27,52 +25,71 @@ object SummaryPage: private val stateSignal = stateVar.signal private val loadMoreBus: EventBus[Unit] = new EventBus + private var currentSubscription: Option[Subscription] = None + private var currentClose: Option[() => Unit] = None + private def resetState(): Unit = stateVar.set(PageState()) - private def fetchSummaryBatch(): EventStream[Try[Option[SummaryResponse]]] = - FetchStream - .withDecoder(responseDecoder[SummaryResponse]) - .get("/api/summarize") - - private val batchObserver: Observer[Try[Option[SummaryResponse]]] = Observer { - case Success(Some(resp)) if resp.funFact.isDefined => - stateVar.update(_.copy(isLoading = false, hasMore = false, funFact = resp.funFact)) - - case Success(Some(resp)) if resp.feedsProcessed > 0 => - val (newContent, isError) = resp.result match - case SummarySuccess(html) => (html, false) - case SummaryError(message) => (message, true) - stateVar.update(s => - s.copy( - isLoading = false, - summaries = s.summaries :+ newContent, - hasError = isError, - totalProcessed = s.totalProcessed + resp.feedsProcessed, - hasMore = resp.hasMore - ) + private def cleanup(): Unit = + currentSubscription.foreach(_.kill()) + currentClose.foreach(_()) + currentSubscription = None + currentClose = None + + private def startStreaming(offset: Int): Unit = + cleanup() + + stateVar.update(s => + s.copy( + isLoading = true, + hasError = false, + summaries = if offset > 0 then s.summaries :+ "" else s.summaries ) - Home.refreshUnreadCountBus.emit(()) + ) + + val (stream, close) = NetworkUtils.streamSummary(s"/api/summarize?offset=$offset") + currentClose = Some(close) + + currentSubscription = Some(stream.foreach { + case Success(SummaryEvent.Content(text)) => + stateVar.update(s => + val newSummaries = + if s.summaries.isEmpty then List(text) + else s.summaries.init :+ (s.summaries.last + text) + s.copy(summaries = newSummaries) + ) + + case Success(SummaryEvent.Metadata(processed, remaining, more)) => + stateVar.update(s => + s.copy(totalProcessed = s.totalProcessed + processed, hasMore = more) + ) + Home.refreshUnreadCountBus.emit(()) + + case Success(SummaryEvent.FunFact(text)) => + stateVar.update(_.copy(funFact = Some(text), isLoading = false)) + + case Success(SummaryEvent.Error(msg)) => + stateVar.update(_.copy(hasError = true, isLoading = false)) + client.NotifyComponent.errorMessage(new RuntimeException(msg)) - case Success(_) => - stateVar.update(_.copy(isLoading = false, hasError = true)) + case Success(SummaryEvent.Done) => + stateVar.update(_.copy(isLoading = false)) + cleanup() - case Failure(err) => - stateVar.update(_.copy(isLoading = false, hasError = true)) - handleError(err) - } + case Failure(err) => + stateVar.update(_.copy(hasError = true, isLoading = false)) + cleanup() + handleError(err) + }(unsafeWindowOwner)) def render: Element = resetState() - val initialFetch = fetchSummaryBatch() div( cls := "main-content", - initialFetch --> batchObserver, - onMountBind { ctx => - loadMoreBus.events.flatMapSwitch { _ => - stateVar.update(_.copy(isLoading = true)) - fetchSummaryBatch() - } --> batchObserver - }, + onMountUnmountCallback(mount = _ => startStreaming(0), unmount = _ => cleanup()), + loadMoreBus.events.map(_ => stateVar.now().totalProcessed) --> (offset => + startStreaming(offset) + ), Card( _.slots.header := CardHeader( _.titleText := "AI Summary", diff --git a/scripts/local-docker/docker-compose.yml b/scripts/local-docker/docker-compose.yml index decda0c..5bd2bc7 100644 --- a/scripts/local-docker/docker-compose.yml +++ b/scripts/local-docker/docker-compose.yml @@ -24,7 +24,7 @@ services: - host.docker.internal:host-gateway server: - image: server:2.4.2 + image: server:2.4.3 container_name: rss_server restart: always depends_on: diff --git a/server/src/main/scala/ru/trett/rss/server/codecs/SummaryCodecs.scala b/server/src/main/scala/ru/trett/rss/server/codecs/SummaryCodecs.scala index 192e78e..2bbab72 100644 --- a/server/src/main/scala/ru/trett/rss/server/codecs/SummaryCodecs.scala +++ b/server/src/main/scala/ru/trett/rss/server/codecs/SummaryCodecs.scala @@ -3,7 +3,13 @@ package ru.trett.rss.server.codecs import io.circe.{Decoder, Encoder} import io.circe.generic.semiauto.* import io.circe.syntax.* -import ru.trett.rss.models.{SummaryResult, SummarySuccess, SummaryError, SummaryResponse} +import ru.trett.rss.models.{ + SummaryResult, + SummarySuccess, + SummaryError, + SummaryResponse, + SummaryEvent +} object SummaryCodecs: given Encoder[SummarySuccess] = deriveEncoder @@ -34,3 +40,17 @@ object SummaryCodecs: given Encoder[SummaryResponse] = deriveEncoder given Decoder[SummaryResponse] = deriveDecoder + + import SummaryEvent.* + given Encoder[Content] = deriveEncoder + given Encoder[Metadata] = deriveEncoder + given Encoder[FunFact] = deriveEncoder + given Encoder[Error] = deriveEncoder + + given Encoder[SummaryEvent] = Encoder.instance { + case c: Content => c.asJson.mapObject(_.add("type", "content".asJson)) + case m: Metadata => m.asJson.mapObject(_.add("type", "metadata".asJson)) + case f: FunFact => f.asJson.mapObject(_.add("type", "funFact".asJson)) + case e: Error => e.asJson.mapObject(_.add("type", "error".asJson)) + case Done => io.circe.Json.obj("type" -> "done".asJson) + } diff --git a/server/src/main/scala/ru/trett/rss/server/controllers/SummarizeController.scala b/server/src/main/scala/ru/trett/rss/server/controllers/SummarizeController.scala index d3c3935..3b900fb 100644 --- a/server/src/main/scala/ru/trett/rss/server/controllers/SummarizeController.scala +++ b/server/src/main/scala/ru/trett/rss/server/controllers/SummarizeController.scala @@ -2,8 +2,9 @@ package ru.trett.rss.server.controllers import cats.effect.IO import org.http4s.AuthedRoutes -import org.http4s.circe.CirceEntityEncoder.* +import org.http4s.ServerSentEvent import org.http4s.dsl.io.* +import io.circe.syntax.* import ru.trett.rss.server.models.User import ru.trett.rss.server.services.SummarizeService import ru.trett.rss.server.codecs.SummaryCodecs.given @@ -15,8 +16,8 @@ object SummarizeController: def routes(summarizeService: SummarizeService): AuthedRoutes[User, IO] = AuthedRoutes.of[User, IO] { case GET -> Root / "api" / "summarize" :? OffsetQueryParamMatcher(offset) as user => - for - summary <- summarizeService.getSummary(user, offset.getOrElse(0)) - response <- Ok(summary) - yield response + val stream = summarizeService + .streamSummary(user, offset.getOrElse(0)) + .map(event => ServerSentEvent(data = Some(event.asJson.noSpaces))) + Ok(stream) } diff --git a/server/src/main/scala/ru/trett/rss/server/services/SummarizeService.scala b/server/src/main/scala/ru/trett/rss/server/services/SummarizeService.scala index 05f84d8..8b92029 100644 --- a/server/src/main/scala/ru/trett/rss/server/services/SummarizeService.scala +++ b/server/src/main/scala/ru/trett/rss/server/services/SummarizeService.scala @@ -1,6 +1,7 @@ package ru.trett.rss.server.services import cats.effect.IO +import fs2.Stream import io.circe.Decoder import io.circe.Json import io.circe.generic.auto.* @@ -15,14 +16,7 @@ import org.http4s.client.Client import org.typelevel.ci.* import org.typelevel.log4cats.Logger import org.typelevel.log4cats.LoggerFactory -import ru.trett.rss.models.{ - SummaryLanguage, - SummaryModel, - SummaryResponse, - SummaryResult, - SummarySuccess, - SummaryError -} +import ru.trett.rss.models.{SummaryLanguage, SummaryModel, SummaryEvent} import ru.trett.rss.server.models.User import ru.trett.rss.server.repositories.FeedRepository import org.jsoup.Jsoup @@ -41,15 +35,17 @@ class SummarizeService(feedRepository: FeedRepository, client: Client[IO], apiKe private val logger: Logger[IO] = LoggerFactory[IO].getLogger private val batchSize = 30 - private def getEndpoint(modelId: String): Uri = + private def getEndpoint(modelId: String, stream: Boolean = false): Uri = + val method = if stream then "streamGenerateContent" else "generateContent" Uri.unsafeFromString( - s"https://generativelanguage.googleapis.com/v1beta/models/$modelId:generateContent" + s"https://generativelanguage.googleapis.com/v1beta/models/$modelId:$method" ) private def buildGeminiRequest( modelId: String, prompt: String, - temperature: Option[Double] = None + temperature: Option[Double] = None, + stream: Boolean = false ): IO[Request[IO]] = val baseConfig = Json.obj( "contents" -> Json @@ -79,7 +75,7 @@ class SummarizeService(feedRepository: FeedRepository, client: Client[IO], apiKe IO.pure( Request[IO]( method = Method.POST, - uri = getEndpoint(modelId), + uri = getEndpoint(modelId, stream), headers = Headers( Header.Raw(ci"X-goog-api-key", apiKey), Header.Raw(ci"Content-Type", "application/json") @@ -87,63 +83,56 @@ class SummarizeService(feedRepository: FeedRepository, client: Client[IO], apiKe ).withEntity(config) ) - def getSummary(user: User, offset: Int): IO[SummaryResponse] = + def streamSummary(user: User, offset: Int): Stream[IO, SummaryEvent] = val selectedModel = user.settings.summaryModel .flatMap(SummaryModel.fromString) .getOrElse(SummaryModel.default) - for - totalUnread <- feedRepository.getTotalUnreadCount(user.id) - feeds <- feedRepository.getUnreadFeeds(user, batchSize, offset) - response <- - if feeds.isEmpty && offset == 0 then - // No feeds at all - generate fun fact - generateFunFact(user, selectedModel.modelId).map(funFact => - SummaryResponse( - result = SummarySuccess(""), - hasMore = false, - feedsProcessed = 0, - totalRemaining = 0, - funFact = Some(funFact) + Stream + .eval(feedRepository.getTotalUnreadCount(user.id)) + .flatMap { totalUnread => + Stream.eval(feedRepository.getUnreadFeeds(user, batchSize, offset)).flatMap { + feeds => + val remainingAfterThis = totalUnread - offset - feeds.size + val metadata = SummaryEvent.Metadata( + feedsProcessed = feeds.size, + totalRemaining = Math.max(0, remainingAfterThis), + hasMore = remainingAfterThis > 0 ) - ) - else if feeds.isEmpty then - // No more feeds (reached end of pagination) - IO.pure( - SummaryResponse( - result = SummarySuccess(""), - hasMore = false, - feedsProcessed = 0, - totalRemaining = 0, - funFact = None - ) - ) - else - val text = feeds.map(_.description).mkString("\n") - val strippedText = Jsoup.parse(text).text() - val validatedLanguage = user.settings.summaryLanguage - .flatMap(SummaryLanguage.fromString) - .getOrElse(SummaryLanguage.English) - - for - summaryResult <- summarize( - strippedText, - validatedLanguage.displayName, - selectedModel.modelId + + Stream.emit(metadata) ++ ( + if feeds.isEmpty && offset == 0 then + Stream + .eval(generateFunFact(user, selectedModel.modelId)) + .map(SummaryEvent.FunFact(_)) ++ Stream.emit(SummaryEvent.Done) + else if feeds.isEmpty then Stream.emit(SummaryEvent.Done) + else + val text = feeds.map(_.description).mkString("\n") + val strippedText = Jsoup.parse(text).text() + val validatedLanguage = user.settings.summaryLanguage + .flatMap(SummaryLanguage.fromString) + .getOrElse(SummaryLanguage.English) + + Stream + .eval( + if user.settings.isAiMode then + feedRepository.markFeedAsRead(feeds.map(_.link), user) + else IO.unit + ) + .drain ++ summarizeStream( + strippedText, + validatedLanguage.displayName, + selectedModel.modelId + ) ++ Stream.emit(SummaryEvent.Done) ) - _ <- summaryResult match - case _: SummarySuccess if user.settings.isAiMode => - feedRepository.markFeedAsRead(feeds.map(_.link), user) - case _ => IO.unit - remainingAfterThis = totalUnread - offset - feeds.size - yield SummaryResponse( - result = summaryResult, - hasMore = remainingAfterThis > 0, - feedsProcessed = feeds.size, - totalRemaining = Math.max(0, remainingAfterThis), - funFact = None - ) - yield response + } + } + .handleErrorWith { error => + Stream.eval(logger.error(error)("Error in streamSummary")).drain ++ + Stream.emit( + SummaryEvent.Error("Error generating summary: " + error.getMessage) + ) ++ Stream.emit(SummaryEvent.Done) + } private def generateFunFact(user: User, modelId: String): IO[String] = val validatedLanguage = user.settings.summaryLanguage @@ -185,7 +174,11 @@ class SummarizeService(feedRepository: FeedRepository, client: Client[IO], apiKe } } - private def summarize(text: String, language: String, modelId: String): IO[SummaryResult] = + private def summarizeStream( + text: String, + language: String, + modelId: String + ): Stream[IO, SummaryEvent] = val prompt = s"""You must follow these rules for your response: |1. Provide only the raw text of the code. |2. Do NOT use any markdown formatting. @@ -202,42 +195,47 @@ class SummarizeService(feedRepository: FeedRepository, client: Client[IO], apiKe |13. For each topic, list the key stories with brief summaries. |Now, following these rules exactly summarize the following text. Answer in $language: $text.""".stripMargin - buildGeminiRequest(modelId, prompt).flatMap { request => - client - .run(request) - .use { response => + Stream + .eval(buildGeminiRequest(modelId, prompt, stream = true)) + .flatMap { request => + client.stream(request).flatMap { response => if response.status.isSuccess then - response - .as[GeminiResponse] + response.body + .through(fs2.text.utf8.decode) + .through(io.circe.fs2.stringArrayParser) + .through(io.circe.fs2.decoder[IO, GeminiResponse]) .map { geminiResp => geminiResp.candidates.headOption .flatMap(_.content.parts.flatMap(_.headOption)) .map(_.text) .map { text => if text.startsWith("```html") then - text.stripPrefix("```html").stripSuffix("```").trim - else text.trim - } match - case Some(html) if html.nonEmpty => SummarySuccess(html) - case _ => - SummaryError("Could not extract summary from response.") + text.stripPrefix("```html").stripSuffix("```") + else text + } + .getOrElse("") } + .map(SummaryEvent.Content(_)) else - response.bodyText.compile.string.flatMap { body => - logger.error( - s"Gemini API error: status=${response.status}, body=$body" - ) *> - IO.pure(SummaryError(s"API error: ${response.status.reason}")) - } - } - .handleErrorWith { error => - val errorMessage = error match - case _: TimeoutException => - "Summary request timed out. The AI service is taking too long to respond. Please try again with fewer feeds." - case _ => - "Error communicating with the summary API." - logger.error(error)(s"Error summarizing text: ${error.getMessage}") *> IO.pure( - SummaryError(errorMessage) - ) + Stream + .eval(response.bodyText.compile.string.flatMap { body => + logger.error( + s"Gemini API stream error: status=${response.status}, body=$body" + ) + }) + .drain ++ Stream.emit( + SummaryEvent.Error(s"API error: ${response.status.reason}") + ) } - } + } + .handleErrorWith { error => + val errorMessage = error match + case _: TimeoutException => + "Summary request timed out." + case _ => + "Error communicating with the summary API." + Stream + .eval(logger.error(error)(s"Error summarizing text: ${error.getMessage}")) + .drain ++ + Stream.emit(SummaryEvent.Error(errorMessage)) + } diff --git a/shared/src/main/scala/ru/trett/rss/models/SummaryEvent.scala b/shared/src/main/scala/ru/trett/rss/models/SummaryEvent.scala new file mode 100644 index 0000000..468d587 --- /dev/null +++ b/shared/src/main/scala/ru/trett/rss/models/SummaryEvent.scala @@ -0,0 +1,12 @@ +package ru.trett.rss.models + +sealed trait SummaryEvent + +object SummaryEvent { + case class Content(text: String) extends SummaryEvent + case class Metadata(feedsProcessed: Int, totalRemaining: Int, hasMore: Boolean) + extends SummaryEvent + case class FunFact(text: String) extends SummaryEvent + case class Error(message: String) extends SummaryEvent + case object Done extends SummaryEvent +}