Write strategies to generate array shapes with total size less than certain value


I am trying to write a strategy that generates array shapes with 4 dimensions, where the product of all dims is less than a given value (say 16728).

That means the search space has its root at (1, 1, 1, 1) and 4 leaves: (16728, 1, 1, 1), (1, 16728, 1, 1), (1, 1, 16728, 1) and (1, 1, 1, 16728).

Code that I am using:

# test_shapes.py
import numpy as np
from hypothesis import settings, HealthCheck, given
from hypothesis.extra.numpy import array_shapes


@settings(max_examples=10000, suppress_health_check=HealthCheck.all())
@given(shape=array_shapes(min_dims=4, max_dims=4, min_side=1, max_side=16728).filter(lambda x: np.prod(x) < 16728))
def test_shape(shape):
    print(f"testing shape: {shape}")

is not performant enough: filtering rejects far too many examples, and randomization does not explore paths other than the one towards the leaf (16728, 1, 1, 1).

pytest test_shapes.py --hypothesis-show-statistics

test_shapes.py::test_shape:

  - during generate phase (211.31 seconds):
    - Typical runtimes: 0-1 ms, ~ 84% in data generation
    - 51 passing examples, 0 failing examples, 99949 invalid examples
    - Events:
      * 99.95%, Retried draw from array_shapes(max_dims=4, max_side=16728, min_dims=4).filter(lambda x: np.prod(x) < 16728) to satisfy filter
      * 99.95%, Aborted test because unable to satisfy array_shapes(max_dims=4, max_side=16728, min_dims=4).filter(lambda x: np.prod(x) < 16728)

  - Stopped because settings.max_examples=10000, but < 10% of examples satisfied assumptions
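
For intuition, here is a quick simulation of my own (not part of Hypothesis itself) showing why the filter is so wasteful: with uniform sampling of sides in [1, 16728] the acceptance probability is essentially zero. Hypothesis still manages ~0.05% valid examples above, presumably because it biases integer draws towards small and boundary values, but that is far too low.

import numpy as np

# Rough sanity check: how often does a uniformly random 4-tuple of sides
# in [1, 16728] have a product below 16728?
rng = np.random.default_rng(0)
sides = rng.integers(1, 16729, size=(1_000_000, 4))
print((sides.prod(axis=1) < 16728).mean())  # prints 0.0 or very close to it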

Is there a better way to write a strategy in Hypothesis that explores the paths to the other leaves equally well?


There are 2 answers below.

BEST ANSWER

Good question! This is a pretty general trick: instead of using a filter, we ensure that every example is valid by construction:

import numpy as np
from hypothesis import given, strategies as st

@st.composite
def small_shapes(draw, *, ndims=4, max_elems=16728):
    # Instead of filtering, we calculate the "remaining cap" if the product
    # of our side lengths is to remain <= max_elems.  Ensuring this by
    # construction is much more efficient than filtering.
    shape = []
    for _ in range(ndims):
        side = draw(st.integers(1, max_elems))
        max_elems //= side
        shape.append(side)
    # However, it *does* bias towards having smaller sides for later
    # dimensions, which we correct by shuffling the list.
    shuffled = draw(st.permutations(shape))
    return tuple(shuffled)

@given(shape=small_shapes())
def test_shape(shape):
    print(f"testing shape: {shape}")
    assert 1 <= np.prod(shape) <= 16728

The "shuffle to remove bias" step is also a reusable tip. Finally - though I didn't need to here - the best option is often to use a constructive approach to make it more likely that the data will be valid... and then apply a filter to take care of the remaining 5-10% of examples that it didn't manage.

ANSWER

I played around with the solution provided by Zac, and I think I found one that exhibits better shrinking behavior while being only slightly less efficient.

import numpy as np
from hypothesis import given, strategies as st, settings


@st.composite
def small_shapes(draw, *, ndims=4, max_elems=16728):
    # We try the naive strategy first, we might get lucky, even if that's
    # unlikely, but this has good shrinking behaviour towards a valid shape,
    # so shrinking may make our luck.
    shape = [draw(st.integers(1, max_elems)) for _ in range(ndims)]
    # Rather than discarding invalid shapes, we try to fix them up, going
    # through the indices in a random order to avoid any bias. This ensures 
    # that the final shape is meaningfully connected to our initial draw.
    shuffled_indices = draw(st.permutations(range(ndims)))
    for index in shuffled_indices:
        side = shape[index]
        if side > max_elems:
            # This side is too big, so we replace it with a valid draw.
            side = draw(st.integers(1, max_elems))
            shape[index] = side
        max_elems //= side
    return tuple(shape)

This produced good minimal examples for both of these tests:

@settings(database=None)
@given(shape=small_shapes(ndims=4, max_elems=104_857_600))
def test_shape_not_1(shape):
    if shape[1] == 1:
        assert 0, "not supported"


@settings(database=None)
@given(shape=small_shapes(ndims=4, max_elems=104_857_600))
def test_shape_not_big(shape):
    if shape[1] >= 104_857_600 // 2:
        assert 0, "not supported"
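
If you want to compare the shrinking behavior of the two strategies directly, hypothesis.find() returns the minimal example satisfying a predicate. Here is a rough sketch, assuming the small_shapes strategy defined above:

from hypothesis import find

# find() reports the minimal example satisfying the predicate, so smaller
# results here indicate better shrinking.  Assumes small_shapes is the
# strategy defined above.
minimal = find(small_shapes(ndims=4, max_elems=104_857_600),
               lambda s: s[1] != 1)
print(minimal)  # hopefully something close to (1, 2, 1, 1)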