Restricting stream size for a URL or local file source

59 Views Asked by At

I have an Akka HTTP API where users send an S3 URL to the server. The server then starts a stream from AWS and performs further actions on the source. However, I would like to validate the size of the incoming stream before performing any operation. Since we cannot use the `withSizeLimit` akka-http directive here, I have created a custom implementation for this.

/**
 * Stream attribute describing how many bytes a [[Limitable]] stage may let through.
 *
 * @param maxBytes      maximum number of bytes allowed; a negative value disables the limit
 * @param contentLength length declared up-front for the data, if known; when present the stage
 *                      only compares it against `maxBytes` and does not enforce a running count
 */
final case class SizeLimit(maxBytes: Long, contentLength: Option[Long] = None) extends Attributes.Attribute {

  /** Whether limiting is switched off (signalled by a negative `maxBytes`). */
  def isDisabled: Boolean = maxBytes < 0L
}

/**
 * Helpers that attach a [[SizeLimit]] attribute to a source and, unless the limit is disabled,
 * route it through a byte-counting [[Limitable]] stage.
 */
object Limitable {

  /** Limits a stream of `ByteString`s, measuring each element by its byte count. */
  def applyForByteStrings[Mat](source: Source[ByteString, Mat], limit: SizeLimit): Source[ByteString, Mat] =
    applyLimit(source, limit)(bytes => bytes.size)

  /** Limits a stream of HTTP chunks, measuring each chunk by the size of its data. */
  def applyForChunks[Mat](source: Source[ChunkStreamPart, Mat], limit: SizeLimit): Source[ChunkStreamPart, Mat] =
    applyLimit(source, limit)(chunk => chunk.data.size)

  /**
   * Attaches `limit` to `source` and, when the limit is enabled, inserts a [[Limitable]] stage.
   *
   * @param source the stream to limit
   * @param limit  the size limit, carried as a stream attribute
   * @param sizeOf measures the size in bytes of a single element
   * @return the limited source with the same materialized value as `source`
   */
  def applyLimit[T, Mat](source: Source[T, Mat], limit: SizeLimit)(sizeOf: T => Int): Source[T, Mat] = {
    // A disabled limit needs no counting stage. The attribute is attached either way so that an
    // existing stage further inside can still observe it.
    val staged = if (limit.isDisabled) source else source.via(new Limitable(sizeOf))
    staged.withAttributes(Attributes(limit))
  }

  private val limitableDefaults = Attributes.name("limitable")
}

/**
 * Pass-through flow stage that subtracts the size of each element (as measured by `sizeOf`) from
 * the configured [[SizeLimit]] and fails the stream with `EntityStreamSizeException` once the
 * budget goes negative.
 *
 * The limit is read from the stage's attributes in `preStart`, so it has to be attached with
 * `withAttributes(Attributes(limit))`, as `Limitable.applyLimit` does.
 *
 * Elements are only counted when they flow through. `onPull` is the only place that requests
 * upstream data, so if downstream never demands more, the rest of the input is never inspected
 * and no failure is raised.
 *
 * @param sizeOf measures the size in bytes of one element
 */
final class Limitable[T](sizeOf: T => Int) extends GraphStage[FlowShape[T, T]] {
  val in: Inlet[T] = Inlet[T]("Limitable.in")
  val out: Outlet[T] = Outlet[T]("Limitable.out")

  // Debug counters printed by the handlers below.
  // NOTE(review): these live on the stage blueprint rather than inside createLogic. Every
  // materialization of the same stage instance therefore increments the same fields, possibly
  // from different threads. Consider moving them into the logic or removing them.
  var numPullCalls = 0
  var numPushCalls = 0

  override val shape: FlowShape[T, T] = FlowShape(in, out)
  override protected val initialAttributes: Attributes = Limitable.limitableDefaults

  override def createLogic(_attributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with InHandler with OutHandler {
      // Limit reported in the failure. It stays -1 unless an undeclared-length limit applies.
      private var maxBytes = -1L
      // Remaining byte budget. Long.MaxValue effectively means "never fail".
      private var bytesLeft = Long.MaxValue

      // This logic object serves as both the input and the output handler.
      setHandlers(in, out, this)

      @nowarn("msg=deprecated") // we need getFirst semantics
      override def preStart(): Unit =
        _attributes.getFirst[SizeLimit].foreach {
          case limit if limit.isDisabled =>
            () // "no limit"
          case SizeLimit(bytes, declared @ Some(contentLength)) =>
            // The declared length is checked once, up-front. The running count is left at
            // Long.MaxValue, so later elements are counted but never fail the stage.
            if (contentLength > bytes) failStage(EntityStreamSizeException(bytes, declared))
          case SizeLimit(bytes, None) =>
            maxBytes = bytes
            bytesLeft = bytes
        }

      override def onPush(): Unit = {
        numPushCalls += 1
        println(s"Push calls $numPushCalls")
        val element = grab(in)
        val elementSize = sizeOf(element)
        println(s"Elem size is $elementSize")
        bytesLeft -= elementSize
        if (bytesLeft < 0) {
          println(s"EntityStreamSizeException Bytes left $bytesLeft")
          failStage(EntityStreamSizeException(maxBytes))
        } else push(out, element)
      }

      override def onPull(): Unit = {
        numPullCalls += 1
        println(s"Pull calls $numPullCalls")
        // Upstream is only read when downstream asks for data.
        pull(in)
      }
    }
}


// Demo: limit a local file to 4,000,000 bytes and hand the result to Tika as an InputStream.
val filePath = Paths.get("/Users/<username>/Documents/bigfile.pdf")
val fileSource: Source[ByteString, Any] = FileIO.fromPath(filePath)
// No declared contentLength, so the Limitable stage counts every chunk against the 4,000,000-byte budget.
val res = Limitable.applyForByteStrings(fileSource, SizeLimit(4000000L))
val sink   = StreamConverters.asInputStream()
val result = res.runWith(sink)
// NOTE(review): the limit stage only sees bytes that downstream demands. asInputStream presumably
// requests data only as the InputStream is read (plus a small read-ahead buffer), so the size
// failure can only surface once `result`/`tis` is read past the limit. It is likely reported as an
// IOException from read(); confirm against the Akka StreamConverters docs.
val tis = TikaInputStream.get(result)

The custom implementation is based on: https://github.com/akka/akka-http/blob/main/akka-http-core/src/main/scala/akka/http/scaladsl/model/HttpEntity.scala

With this custom implementation, I expect the server to throw `EntityStreamSizeException` if the file is larger than 4 MB, but no exception is thrown. What am I missing here?

1

There is 1 answer below

0
Sushant Somani On

The documentation says:

> Demand flowing upstream leading to elements flowing downstream.

I was not consuming the output stream anywhere, so only 16 elements were processed, which is the default buffer size of the sink. One way to consume the output stream is:

val fileSource: Source[ByteString, Any] = FileIO.fromPath(filePath)
val res = Limitable.applyForByteStrings(fileSource, SizeLimit(4000000L))
// Running the stream into a sink that demands every element makes the sink keep pulling through the
// Limitable stage, so the whole file is counted and the size check can trigger.
// NOTE(review): this prints every ByteString, and the Future returned by runForeach is discarded here.
// An EntityStreamSizeException would therefore fail that Future without being observed; handle the
// returned Future's failure in real code.
res.runForeach(println)