My custom metal image filter is slow. How can I make it faster?


I've seen a lot of other online tutorials that are able to hit the 0.0X-second mark when filtering an image. Meanwhile, my code here takes 1.09 seconds to filter an image (just to reduce brightness by half).

Edit after first comment: time was measured with 2 methods

  • Date() timeInterval, taken when the "apply filter" button is tapped and again after the apply-filter function is done running (a rough sketch of this measurement follows below)
  • building it on an iPhone and counting manually with the timer on my watch
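
Roughly, the first measurement looks like the sketch below. The gpuStartTime/gpuEndTime readout at the end is just an idea for separating the GPU work from the CPU-side conversions; it isn't in my code yet:

// Wall-clock time around the whole call (this is what the 1.09 s number measures).
let start = Date()
let result = Filter().applyFilter()   // the function shown further down
print("total wall-clock:", Date().timeIntervalSince(start))

// Inside applyFilter(), right after buffer.waitUntilCompleted(), the GPU portion
// alone could be read from the command buffer:
// print("gpu only:", buffer.gpuEndTime - buffer.gpuStartTime)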

Since I'm new to Metal & kernel stuff, I don't really know the difference between my code and those tutorials that achieve faster results. Which part of my code can be improved, or what different approach could I use, to make it a lot faster?

Here's my kernel code:

#include <metal_stdlib>
using namespace metal;
kernel void black(
               texture2d<float, access::write> outTexture [[texture(0)]],
               texture2d<float, access::read> inTexture [[texture(1)]],
               uint2 id [[thread_position_in_grid]]) {
    float3 val = inTexture.read(id).rgb;
    float r = val.r / 4;
    float g = val.g / 4;
    float b = val.b / 2;
    float4 out = float4(r, g, b, 1.0);
    outTexture.write(out.rgba, id);
}

This is my Swift code:

import Metal
import MetalKit

 // UIImage -> CGImage -> MTLTexture -> COMPUTE HAPPENS |
 //                 UIImage <- CGImage <- MTLTexture <--
 class Filter {

var device: MTLDevice
var defaultLib: MTLLibrary?
var grayscaleShader: MTLFunction?
var commandQueue: MTLCommandQueue?
var commandBuffer: MTLCommandBuffer?
var commandEncoder: MTLComputeCommandEncoder?
var pipelineState: MTLComputePipelineState?

var inputImage: UIImage
var height, width: Int

// most devices have a limit of 512 threads per group
let threadsPerBlock = MTLSize(width: 32, height: 32, depth: 1)

init(){
    
    print("initialized")
    self.device = MTLCreateSystemDefaultDevice()!
    print(device)
    
    //changes: I did try do/catch, and used the bundle parameter when making the default library
    
    let frameworkBundle = Bundle(for: type(of: self))
    print(frameworkBundle)
    
   
    self.defaultLib = device.makeDefaultLibrary()
   

    
    
    self.grayscaleShader = defaultLib?.makeFunction(name: "black")
    self.commandQueue = self.device.makeCommandQueue()
    
    self.commandBuffer = self.commandQueue?.makeCommandBuffer()
    self.commandEncoder = self.commandBuffer?.makeComputeCommandEncoder()
    
    
    //ERROR HERE
    if let shader = grayscaleShader {
        print("in")
        self.pipelineState = try? self.device.makeComputePipelineState(function: shader)
        
    } else { fatalError("unable to make compute pipeline") }
    
    self.inputImage = UIImage(named: "stockImage")!
    self.height = Int(self.inputImage.size.height)
    self.width = Int(self.inputImage.size.width)
    
}

func getCGImage(from uiimg: UIImage) -> CGImage? {
    
    UIGraphicsBeginImageContext(uiimg.size)
    uiimg.draw(in: CGRect(origin: .zero, size: uiimg.size))
    let contextImage = UIGraphicsGetImageFromCurrentImageContext()
    UIGraphicsEndImageContext()
    
    return contextImage?.cgImage
    
}

func getMTLTexture(from cgimg: CGImage) -> MTLTexture {
    
    let textureLoader = MTKTextureLoader(device: self.device)
    
    do{
        let texture = try textureLoader.newTexture(cgImage: cgimg, options: nil)
        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: texture.pixelFormat, width: width, height: height, mipmapped: false)
        textureDescriptor.usage = [.shaderRead, .shaderWrite]
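        // note: this descriptor is created but never applied; the texture returned below is the one produced by the loader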
        return texture
    } catch {
        fatalError("Couldn't convert CGImage to MTLtexture")
    }
    
}

func getCGImage(from mtlTexture: MTLTexture) -> CGImage? {
    
    var data = Array<UInt8>(repeatElement(0, count: 4*width*height))
    
    mtlTexture.getBytes(&data,
                        bytesPerRow: 4*width,
                        from: MTLRegionMake2D(0, 0, width, height),
                        mipmapLevel: 0)
    
    let bitmapInfo = CGBitmapInfo(rawValue: (CGBitmapInfo.byteOrder32Big.rawValue | CGImageAlphaInfo.premultipliedLast.rawValue))
    
    let colorSpace = CGColorSpaceCreateDeviceRGB()
    
    let context = CGContext(data: &data,
                            width: width,
                            height: height,
                            bitsPerComponent: 8,
                            bytesPerRow: 4*width,
                            space: colorSpace,
                            bitmapInfo: bitmapInfo.rawValue)
    
    return context?.makeImage()
}

func getUIImage(from cgimg: CGImage) -> UIImage? {
    return UIImage(cgImage: cgimg)
}

func getEmptyMTLTexture() -> MTLTexture? {
    
    let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
        pixelFormat: MTLPixelFormat.rgba8Unorm,
        width: width,
        height: height,
        mipmapped: false)
    
    textureDescriptor.usage = [.shaderRead, .shaderWrite]
    
    return self.device.makeTexture(descriptor: textureDescriptor)
}

func getInputMTLTexture() -> MTLTexture? {
    if let inputImage = getCGImage(from: self.inputImage) {
        return getMTLTexture(from: inputImage)
    }
    else { fatalError("Unable to convert Input image to MTLTexture") }
}

func getBlockDimensions() -> MTLSize {
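    // note: integer division truncates, so pixels beyond the last full 32x32 block are not dispatched when width/height aren't multiples of 32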
    let blockWidth = width / self.threadsPerBlock.width
    let blockHeight = height / self.threadsPerBlock.height
    return MTLSizeMake(blockWidth, blockHeight, 1)
}

func applyFilter() -> UIImage? {
    print("start")
    let date = Date()
    print(date)
    
    if let encoder = self.commandEncoder, let buffer = self.commandBuffer,
        let outputTexture = getEmptyMTLTexture(), let inputTexture = getInputMTLTexture() {
        
        encoder.setTextures([outputTexture, inputTexture], range: 0..<2)
        encoder.setComputePipelineState(self.pipelineState!)
        encoder.dispatchThreadgroups(self.getBlockDimensions(), threadsPerThreadgroup: threadsPerBlock)
        encoder.endEncoding()
        
        buffer.commit()
        buffer.waitUntilCompleted()
        
        guard let outputImage = getCGImage(from: outputTexture) else { fatalError("Couldn't obtain CGImage from MTLTexture") }
        
        print("stop")
        
        let date2 = Date()
        print(date2.timeIntervalSince(date))
        return getUIImage(from: outputImage)
        
    } else { fatalError("optional unwrapping failed") }
    
}


}

There are 2 answers below.


In case someone still needs the answer: I found a different approach, which is to make it a custom CIFilter. It works pretty fast and is super easy to understand!
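
For illustration, a rough sketch of what that custom-CIFilter approach can look like. The class name and kernel source below are made up for this example, and CIColorKernel(source:) is soft-deprecated in recent SDKs, but it still works:

import CoreImage

// Illustrative custom CIFilter that halves brightness (names are hypothetical).
class HalfBrightnessFilter: CIFilter {
    @objc dynamic var inputImage: CIImage?

    // Core Image Kernel Language source: scale RGB by 0.5, keep alpha.
    private static let kernel = CIColorKernel(source:
        "kernel vec4 halfBrightness(__sample s) { return vec4(s.rgb * 0.5, s.a); }"
    )!

    override var outputImage: CIImage? {
        guard let input = inputImage else { return nil }
        return HalfBrightnessFilter.kernel.apply(extent: input.extent, arguments: [input])
    }
}

Core Image builds the filter chain lazily and renders it in one pass through a GPU-backed context, so there is no per-filter UIImage/CGImage/MTLTexture round trip like in the code above.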


You are using UIImage and CGImage. These objects are stored in CPU memory.

You need to implement the code using just CIImage or MTLTexture. These objects live in GPU memory and give the best performance.
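
A minimal sketch of that idea, reusing the hypothetical HalfBrightnessFilter from the first answer (any CIFilter works the same way): create one Metal-backed CIContext up front, keep everything as CIImage, and convert back to UIImage only once at the end.

import CoreImage
import Metal
import UIKit

let device = MTLCreateSystemDefaultDevice()!
// One Metal-backed CIContext, created once and reused for every image.
let ciContext = CIContext(mtlDevice: device)

func filtered(_ image: UIImage) -> UIImage? {
    guard let input = CIImage(image: image) else { return nil }
    let filter = HalfBrightnessFilter()          // any CIFilter would do
    filter.inputImage = input
    guard let output = filter.outputImage,
          // The only CPU-side copy happens here, once, at the end of the chain.
          let cgImage = ciContext.createCGImage(output, from: output.extent) else { return nil }
    return UIImage(cgImage: cgImage)
}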