How can I create a struct of native pointers in JCuda

I have a CUDA kernel that takes a list of structs.

kernel<<<blockCount,blockSize>>>(MyStruct *structs);

Each struct contains 3 pointers.

typedef struct __align__(16) {
    float* pointer1;
    float* pointer2;
    float* pointer3;
} MyStruct;

I have three device arrays containing floats, and each pointer within the struct points to a float within one of the three device arrays.

The list of structs represents a tree/graph structure, which allows the kernel to execute recursive operations depending on the order of the list of structs that is sent to the kernel. (This part works in C++, so it is not related to my problem.)

What I would like to do is be able to send my struct of pointers from JCuda. I understand that this isn't natively possible unless it is flattened to a padded array as in this post.

I understand all the issues with alignment and padding that may happen when sending a list of structs, it's essentially a repeating padded array which I am fine with.
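To make the padding concrete, here is a minimal sketch of the size calculation for one struct slot. It assumes three pointers per struct and 16-byte alignment, as in the struct above; the class and method names are made up for illustration and are not part of JCuda.

```java
// Sketch: size of one padded struct slot for a struct that only holds pointers.
class StructLayoutSketch
{
    // Rounds 'size' up to the next multiple of 'alignment' (alignment > 0).
    static int alignUp(int size, int alignment)
    {
        return ((size + alignment - 1) / alignment) * alignment;
    }

    // Bytes occupied by one struct that holds 'numPointers' pointers.
    static int paddedStructSize(int numPointers, int pointerSize, int alignment)
    {
        return alignUp(numPointers * pointerSize, alignment);
    }

    public static void main(String[] args)
    {
        // 3 pointers of 8 bytes = 24 bytes, padded to 32
        System.out.println("64-bit: " + paddedStructSize(3, 8, 16) + " bytes");
        // 3 pointers of 4 bytes = 12 bytes, padded to 16
        System.out.println("32-bit: " + paddedStructSize(3, 4, 16) + " bytes");
    }
}
```

In both cases the result equals four pointer-sized slots, which is why a single zeroed slot after the three pointers is enough padding here.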

The bit I am not sure how to do is populating my flattened struct buffer with pointers. For example, I would think I can do something like this:

Pointer A = ....(underlying device array1)
Pointer B = ....(underlying device array2)
Pointer C = ....(underlying device array3)

// Assuming 64-bit pointers: 3 pointers + 1 padding slot, 8 bytes each
ByteBuffer structListBuffer = ByteBuffer.allocate(32*noSteps);
// Create the long view once, so that each put() advances its position
LongBuffer structs = structListBuffer.order(ByteOrder.nativeOrder()).asLongBuffer();
for(int x = 0; x<noSteps; x++) {
    // Get the underlying pointer values
    long pointer1 = A.withByteOffset(getStepOffsetA(x)).someGetUnderlyingPointerValueFunction();
    long pointer2 = B.withByteOffset(getStepOffsetB(x)).someGetUnderlyingPointerValueFunction();
    long pointer3 = C.withByteOffset(getStepOffsetC(x)).someGetUnderlyingPointerValueFunction();

    // Build the struct
    structs.put(pointer1);
    structs.put(pointer2);
    structs.put(pointer3);
    structs.put(0L); //padding
}

structListBuffer would then contain a list of structs in the way that the kernel would expect it.

So is there any way to do the someGetUnderlyingPointerValueFunction() from a ByteBuffer?


If I understood everything correctly, the main point of the question is whether there is such a magic function like

long address = pointer.someGetUnderlyingPointerValueFunction();

that returns the address of the native pointer.

The short answer: No, there is no such function.

(Side note: A similar functionality was already requested quite a while ago, but I have not yet added it, mainly because such a function does not make sense for pointers to Java arrays or (non-direct) byte buffers. Additionally, manually handling structs with their paddings and alignments, pointers with different sizes on 32 and 64 bit machines, and buffers that are big- or little-endian is an endless source of headaches. But I see the point, and the possible application case, and so I'll most likely add something like a getAddress() function. Maybe only to the CUdeviceptr class, where it definitely makes sense, at least more than in the Pointer class. People will use this method to do odd things, and they will do things that will cause nasty crashes of the VM, but JCuda itself is such a thin abstraction layer that there is no safety net in this regard anyhow...)
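To illustrate the pointer size and byte order issues mentioned in the side note, here is a small plain-Java sketch. It does not use JCuda, the address is a made-up value, and encodeAddress is a hypothetical helper. A ByteBuffer is big-endian by default, so the byte order has to be set explicitly when the bytes are meant to be read by native code.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

class ByteOrderSketch
{
    // Encodes one address with the given pointer width (4 or 8 bytes)
    // and byte order, and returns the resulting raw bytes.
    static byte[] encodeAddress(long address, int pointerSize, ByteOrder order)
    {
        ByteBuffer buffer = ByteBuffer.allocate(pointerSize).order(order);
        if (pointerSize == 4)
        {
            buffer.putInt((int) address);
        }
        else
        {
            buffer.putLong(address);
        }
        return buffer.array();
    }

    public static void main(String[] args)
    {
        long address = 0x11223344L; // made-up value, not a real device address
        System.out.println("4 bytes, big endian   : " +
            Arrays.toString(encodeAddress(address, 4, ByteOrder.BIG_ENDIAN)));
        System.out.println("4 bytes, little endian: " +
            Arrays.toString(encodeAddress(address, 4, ByteOrder.LITTLE_ENDIAN)));
        System.out.println("8 bytes, native order : " +
            Arrays.toString(encodeAddress(address, 8, ByteOrder.nativeOrder())));
    }
}
```

The same address yields different byte sequences depending on the width and order, which is why a hand-built struct buffer should use ByteOrder.nativeOrder() and the pointer size of the target platform.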


That said, you can work around the current limitation, with a method like this:

private static long getPointerAddress(CUdeviceptr p)
{
    // WORKAROUND until a method like CUdeviceptr#getAddress exists
    class PointerWithAddress extends Pointer
    {
        PointerWithAddress(Pointer other)
        {
            super(other);
        }
        long getAddress()
        {
            return getNativePointer() + getByteOffset();
        }
    }
    return new PointerWithAddress(p).getAddress();
}

Of course, this is ugly and clearly contradicts the intention of making the getNativePointer() and getByteOffset() methods protected. But it might eventually be replaced with some "official" method:

private static long getPointerAddress(CUdeviceptr p)
{
    return p.getAddress();
}

Until then, the workaround is probably the closest you can get to what is possible on the C side.
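Once the addresses are available as long values, writing them into the flattened struct buffer is plain java.nio work. The following is only a sketch under assumptions: packStructs is a hypothetical helper, the addresses are made-up numbers instead of real device addresses, and each struct is assumed to consist of three pointers plus one zeroed padding slot.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

class StructPackingSketch
{
    // Packs one struct per row of 'addresses'. Each row holds up to three
    // addresses; the remaining slots of the four-slot struct are zero padding.
    static ByteBuffer packStructs(
        long[][] addresses, int pointerSize, ByteOrder order)
    {
        int slotsPerStruct = 4; // 3 pointers + 1 padding slot
        ByteBuffer buffer = ByteBuffer
            .allocate(addresses.length * slotsPerStruct * pointerSize)
            .order(order);
        for (long[] row : addresses)
        {
            for (int slot = 0; slot < slotsPerStruct; slot++)
            {
                long value = slot < row.length ? row[slot] : 0L;
                if (pointerSize == 4)
                {
                    buffer.putInt((int) value);
                }
                else
                {
                    buffer.putLong(value);
                }
            }
        }
        buffer.rewind();
        return buffer;
    }

    public static void main(String[] args)
    {
        long[][] addresses = {
            { 0x10L, 0x20L, 0x30L }, // made-up addresses for struct 0
            { 0x40L, 0x50L, 0x60L }, // made-up addresses for struct 1
        };
        ByteBuffer packed = packStructs(addresses, 8, ByteOrder.nativeOrder());
        System.out.println("Packed " + packed.capacity() + " bytes");
    }
}
```

In a real program, the rows would be filled with the values returned by getPointerAddress, and pointerSize would be Sizeof.POINTER.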


Here is an example that I wrote for testing this. The kernel is only a dummy kernel, that fills the structure with "identifiable" values (to see whether they end up in the right place), and is supposed to be launched with 1 thread only:

typedef struct __align__(16) {
    float* pointer1;
    float* pointer2;
    float* pointer3;
} MyStruct;

extern "C"
__global__ void kernel(MyStruct *structs)
{
    structs[0].pointer1[0] = 1.0f;
    structs[0].pointer1[1] = 1.1f;
    structs[0].pointer1[2] = 1.2f;

    structs[0].pointer2[0] = 2.0f;
    structs[0].pointer2[1] = 2.1f;
    structs[0].pointer2[2] = 2.2f;

    structs[0].pointer3[0] = 3.0f;
    structs[0].pointer3[1] = 3.1f;
    structs[0].pointer3[2] = 3.2f;

    structs[1].pointer1[0] = 11.0f;
    structs[1].pointer1[1] = 11.1f;
    structs[1].pointer1[2] = 11.2f;

    structs[1].pointer2[0] = 12.0f;
    structs[1].pointer2[1] = 12.1f;
    structs[1].pointer2[2] = 12.2f;

    structs[1].pointer3[0] = 13.0f;
    structs[1].pointer3[1] = 13.1f;
    structs[1].pointer3[2] = 13.2f;
}

This kernel is launched in the following program. (Note: The PTX file is compiled on the fly here, with settings that may not match your application case. If in doubt, compile your PTX file manually.)

The pointer1, pointer2 and pointer3 pointers of each struct are initialized so that they point to consecutive elements of the device buffers A, B and C, respectively, each with an offset that allows identifying the values that are written by the kernel. (Note that I tried to handle the two possible cases of running this on either a 32-bit or a 64-bit machine, which implies different pointer sizes, although I can currently only test the 32-bit version.)

import static jcuda.driver.JCudaDriver.*;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;

import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;


public class JCudaPointersInStruct 
{
    public static void main(String args[]) throws IOException
    {
        JCudaDriver.setExceptionsEnabled(true);
        String ptxFileName = preparePtxFile("JCudaPointersInStructKernel.cu");
        cuInit(0);
        CUdevice device = new CUdevice();
        cuDeviceGet(device, 0);
        CUcontext context = new CUcontext();
        cuCtxCreate(context, 0, device);
        CUmodule module = new CUmodule();
        cuModuleLoad(module, ptxFileName);
        CUfunction function = new CUfunction();
        cuModuleGetFunction(function, module, "kernel");

        int numElements = 9;
        CUdeviceptr A = new CUdeviceptr();
        cuMemAlloc(A, numElements * Sizeof.FLOAT);
        cuMemsetD32(A, 0, numElements);
        CUdeviceptr B = new CUdeviceptr();
        cuMemAlloc(B, numElements * Sizeof.FLOAT);
        cuMemsetD32(B, 0, numElements);
        CUdeviceptr C = new CUdeviceptr();
        cuMemAlloc(C, numElements * Sizeof.FLOAT);
        cuMemsetD32(C, 0, numElements);

        int numSteps = 2;
        int sizeOfStruct = Sizeof.POINTER * 4;
        ByteBuffer hostStructsBuffer = 
            ByteBuffer.allocate(numSteps * sizeOfStruct);
        if (Sizeof.POINTER == 4)
        {
            IntBuffer b = hostStructsBuffer.order(
                ByteOrder.nativeOrder()).asIntBuffer();
            for(int x = 0; x<numSteps; x++) 
            {
                CUdeviceptr pointer1 = A.withByteOffset(getStepOffsetA(x));
                CUdeviceptr pointer2 = B.withByteOffset(getStepOffsetB(x));
                CUdeviceptr pointer3 = C.withByteOffset(getStepOffsetC(x));

                //System.out.println("Step "+x+" pointer1 is "+pointer1);
                //System.out.println("Step "+x+" pointer2 is "+pointer2);
                //System.out.println("Step "+x+" pointer3 is "+pointer3);

                b.put((int)getPointerAddress(pointer1));
                b.put((int)getPointerAddress(pointer2));
                b.put((int)getPointerAddress(pointer3));
                b.put(0);
            }
        }
        else
        {
            LongBuffer b = hostStructsBuffer.order(
                ByteOrder.nativeOrder()).asLongBuffer();
            for(int x = 0; x<numSteps; x++) 
            {
                CUdeviceptr pointer1 = A.withByteOffset(getStepOffsetA(x));
                CUdeviceptr pointer2 = B.withByteOffset(getStepOffsetB(x));
                CUdeviceptr pointer3 = C.withByteOffset(getStepOffsetC(x));

                //System.out.println("Step "+x+" pointer1 is "+pointer1);
                //System.out.println("Step "+x+" pointer2 is "+pointer2);
                //System.out.println("Step "+x+" pointer3 is "+pointer3);

                b.put(getPointerAddress(pointer1));
                b.put(getPointerAddress(pointer2));
                b.put(getPointerAddress(pointer3));
                b.put(0);
            }
        }

        CUdeviceptr structs = new CUdeviceptr();
        cuMemAlloc(structs, numSteps * sizeOfStruct);
        cuMemcpyHtoD(structs, Pointer.to(hostStructsBuffer), 
            numSteps * sizeOfStruct);

        Pointer kernelParameters = Pointer.to(
            Pointer.to(structs)
        );
        cuLaunchKernel(function, 
            1, 1, 1, 
            1, 1, 1, 
            0, null, kernelParameters, null);
        cuCtxSynchronize();


        float hostA[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostA), A, numElements * Sizeof.FLOAT);
        float hostB[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostB), B, numElements * Sizeof.FLOAT);
        float hostC[] = new float[numElements];
        cuMemcpyDtoH(Pointer.to(hostC), C, numElements * Sizeof.FLOAT);

        System.out.println("A "+Arrays.toString(hostA));
        System.out.println("B "+Arrays.toString(hostB));
        System.out.println("C "+Arrays.toString(hostC));
    }

    private static long getStepOffsetA(int x)
    {
        return x * Sizeof.FLOAT * 4 + 0 * Sizeof.FLOAT;
    }
    private static long getStepOffsetB(int x)
    {
        return x * Sizeof.FLOAT * 4 + 1 * Sizeof.FLOAT;
    }
    private static long getStepOffsetC(int x)
    {
        return x * Sizeof.FLOAT * 4 + 2 * Sizeof.FLOAT;
    }


    private static long getPointerAddress(CUdeviceptr p)
    {
        // WORKAROUND until a method like CUdeviceptr#getAddress exists
        class PointerWithAddress extends Pointer
        {
            PointerWithAddress(Pointer other)
            {
                super(other);
            }
            long getAddress()
            {
                return getNativePointer() + getByteOffset();
            }
        }
        return new PointerWithAddress(p).getAddress();
    }




    //-------------------------------------------------------------------------
    // Ignore this - in practice, you'll compile the PTX manually
    private static String preparePtxFile(String cuFileName) throws IOException
    {
        int endIndex = cuFileName.lastIndexOf('.');
        if (endIndex == -1)
        {
            endIndex = cuFileName.length()-1;
        }
        String ptxFileName = cuFileName.substring(0, endIndex+1)+"ptx";
        File cuFile = new File(cuFileName);
        if (!cuFile.exists())
        {
            throw new IOException("Input file not found: "+cuFileName);
        }
        String modelString = "-m"+System.getProperty("sun.arch.data.model");
        String command =
            "nvcc " + modelString + " -ptx -arch sm_11 -lineinfo "+
            cuFile.getPath()+" -o "+ptxFileName;
        System.out.println("Executing\n"+command);
        Process process = Runtime.getRuntime().exec(command);
        String errorMessage =
            new String(toByteArray(process.getErrorStream()));
        String outputMessage =
            new String(toByteArray(process.getInputStream()));
        int exitValue = 0;
        try
        {
            exitValue = process.waitFor();
        }
        catch (InterruptedException e)
        {
            Thread.currentThread().interrupt();
            throw new IOException(
                "Interrupted while waiting for nvcc output", e);
        }

        if (exitValue != 0)
        {
            System.out.println("nvcc process exitValue "+exitValue);
            System.out.println("errorMessage:\n"+errorMessage);
            System.out.println("outputMessage:\n"+outputMessage);
            throw new IOException(
                "Could not create .ptx file: "+errorMessage);
        }
        System.out.println("Finished creating PTX file");
        return ptxFileName;
    }
    private static byte[] toByteArray(InputStream inputStream)
        throws IOException
    {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte buffer[] = new byte[8192];
        while (true)
        {
            int read = inputStream.read(buffer);
            if (read == -1)
            {
                break;
            }
            baos.write(buffer, 0, read);
        }
        return baos.toByteArray();
    }

}

The result is, as expected/desired:

A [1.0, 1.1, 1.2, 0.0, 11.0, 11.1, 11.2, 0.0, 0.0]
B [0.0, 2.0, 2.1, 2.2, 0.0, 12.0, 12.1, 12.2, 0.0]
C [0.0, 0.0, 3.0, 3.1, 3.2, 0.0, 13.0, 13.1, 13.2]
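The expected pattern can also be derived without a GPU. The following plain-Java sketch emulates the writes of the dummy kernel, using the same offset formula as getStepOffsetA/B/C. The class name and the simulate method are made up for illustration, and the value table is simply copied from the kernel.

```java
import java.util.Arrays;

class ExpectedOutputSketch
{
    static final int SIZEOF_FLOAT = 4;

    // Values that the dummy kernel writes: [step][array index][element]
    static final float[][][] KERNEL_VALUES = {
        { { 1.0f, 1.1f, 1.2f }, { 2.0f, 2.1f, 2.2f }, { 3.0f, 3.1f, 3.2f } },
        { { 11.0f, 11.1f, 11.2f }, { 12.0f, 12.1f, 12.2f }, { 13.0f, 13.1f, 13.2f } },
    };

    // Same formula as getStepOffsetA/B/C, with k = 0 for A, 1 for B, 2 for C
    static long stepOffset(int x, int k)
    {
        return (long) x * SIZEOF_FLOAT * 4 + (long) k * SIZEOF_FLOAT;
    }

    // Emulates the kernel writes into array k (0 = A, 1 = B, 2 = C)
    static float[] simulate(int numElements, int k)
    {
        float[] result = new float[numElements];
        for (int x = 0; x < KERNEL_VALUES.length; x++)
        {
            // Convert the byte offset into a float index
            int start = (int) (stepOffset(x, k) / SIZEOF_FLOAT);
            for (int i = 0; i < 3; i++)
            {
                result[start + i] = KERNEL_VALUES[x][k][i];
            }
        }
        return result;
    }

    public static void main(String[] args)
    {
        System.out.println("A " + Arrays.toString(simulate(9, 0)));
        System.out.println("B " + Arrays.toString(simulate(9, 1)));
        System.out.println("C " + Arrays.toString(simulate(9, 2)));
    }
}
```

This prints the same three lines as the GPU run above, which confirms that each struct's pointers landed at the intended offsets.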