Errors using onehot_encode: incorrect input format?


I'm trying to use the mx.nd.onehot_encode function, which should be straightforward, but I'm getting errors that are difficult to parse. Here is the example I'm trying:

import mxnet as mx

m0 = mx.nd.zeros(15)  # 1-D target holder with 15 elements
mx.nd.onehot_encode(mx.nd.array([0]), m0)

I expect this to return a 15-element vector (written in place into m0) with only the first element set to 1. Instead I get this error:

src/ndarray/./ndarray_function.h:73: Check failed: index.ndim() == 1 && proptype.ndim() == 2 OneHotEncode only support 1d index.

Neither NDArray is 2-dimensional, so why am I getting this error? Is there some other input format I should be using?

Best answer

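It seems that mxnet.ndarray.onehot_encode requires the target NDArray (the second argument, which the error message calls proptype) to be explicitly 2-D, with one row per index. For a single index that means the shape [1, X]. The check is on the number of dimensions, so a 1-D target of shape (15,) fails even though the index array itself is 1-D.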

I tried:

m0 = mx.nd.zeros((1, 15))  # 2-D: one row for the single index, 15 classes
mx.nd.onehot_encode(mx.nd.array([0]), m0)

It reported no error, and m0 is filled in place with a 1 at position [0, 0].
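To make the shape rule concrete, here is a minimal NumPy sketch that models the contract the error message enforces: the index array must be 1-D and the target must be 2-D. onehot_encode_sketch is a hypothetical helper, not part of MXNet. The row-count check (one row per index, so shape (len(indices), num_classes)) is an assumption consistent with the (1, 15) shape that works above. Indices are cast to int because mx.nd.array produces float values by default.

```python
import numpy as np


def onehot_encode_sketch(indices, out):
    """Hypothetical NumPy model of the shape rules behind mx.nd.onehot_encode.

    indices: 1-D array of class indices (may be float, as mx.nd.array produces).
    out:     2-D array of shape (len(indices), num_classes), overwritten in place.
    """
    indices = np.asarray(indices)
    # Mirrors the failing check: index.ndim() == 1 && proptype.ndim() == 2
    if not (indices.ndim == 1 and out.ndim == 2):
        raise ValueError("OneHotEncode only support 1d index.")
    # Assumption: one output row per index
    if out.shape[0] != indices.shape[0]:
        raise ValueError("out must have one row per index")
    out[:] = 0
    # Row i gets a 1 in the column named by indices[i]
    out[np.arange(indices.shape[0]), indices.astype(int)] = 1
    return out


# The question's call: a 1-D target of shape (15,) fails the dimension check
try:
    onehot_encode_sketch(np.array([0.0]), np.zeros(15))
except ValueError as err:
    print(err)  # prints "OneHotEncode only support 1d index."

# The answer's fix: a (1, 15) target holder
single = onehot_encode_sketch(np.array([0.0]), np.zeros((1, 15)))

# Several indices at once: the target has shape (len(indices), num_classes)
batch = onehot_encode_sketch(np.array([1.0, 0.0, 2.0, 0.0]), np.zeros((4, 3)))
```

The same reasoning should carry over to the real call: to encode several indices, allocate the target with one row per index, e.g. mx.nd.zeros((len(indices), 15)).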