How to find the probability cutoff that maximizes the inner product of two tensors?


I have two tensors:

import torch

target = torch.randint(2, (3, 5))  # tensor of 0s and 1s
pred = torch.rand(3, 5)  # tensor of probabilities

# transformed_pred = ?

How can I choose a cutoff probability to transform pred into a tensor of 0s & 1s (transformed_pred) so that the dot product between target and transformed_pred is maximized?

Thanks!

2

There are 2 best solutions below

1
Christoph On

Maybe I am misunderstanding your question, but the cutoff probability should be 0.0. If every prediction is mapped to 1, the dot product equals the number of ones in your target tensor, which is the largest value it can take.
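A minimal sketch of that argument, reusing the tensors from the question. The fixed seed and the names all_ones, dot_all_ones and upper_bound are assumptions added here for illustration only:

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so the run is reproducible

target = torch.randint(2, (3, 5))  # tensor of 0s and 1s (int64)
pred = torch.rand(3, 5)            # probabilities in [0, 1)

# With a cutoff of 0.0 every prediction is mapped to 1
all_ones = (pred >= 0.0).to(target.dtype)

# The dot product then simply counts the ones in target
dot_all_ones = torch.dot(target.flatten(), all_ones.flatten())
upper_bound = target.sum()

print(dot_all_ones.item(), upper_bound.item())
```

Any higher cutoff can only turn some of those 1s into 0s, so it cannot produce a larger dot product.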

Are there any other constraints you haven't mentioned?

0
inverted_index On

If I understand your question correctly, you want to determine a cutoff (a probability threshold) that converts the probability values in tensor pred to binary labels, such that the dot product between target and the binarized pred is maximized.

Assuming that, I suggest the following steps:

  • Sort the elements in pred and iterate over these values as potential cutoff candidates.
  • For each potential cutoff, create a binary version of pred where values greater than or equal to the cutoff are 1, and others are 0.
  • Compute the dot product of target and the binary version of pred.
  • Keep track of the cutoff that gives the maximum dot product. You may want to plot the dot product over the cutoff values to find the best cutoff point.

Here is code that implements the steps above:

# Flattening the tensors for easier processing
flat_target = target.flatten()
flat_pred = pred.flatten()

# Initialize variables to track the best cutoff and maximum dot product
best_cutoff = None
max_dot_product = -1

# torch.unique returns the distinct values in ascending order; test each as a potential cutoff
unique_pred_values = torch.unique(flat_pred)

for cutoff in unique_pred_values:
    # Create a binary version of pred based on the current cutoff.
    # Cast to the dtype of flat_target, since torch.dot requires matching dtypes
    transformed_pred = (flat_pred >= cutoff).to(flat_target.dtype)

    # Compute the dot product with the target
    dot_product = torch.dot(flat_target, transformed_pred)

    # Check if this dot product is the maximum found so far
    if dot_product > max_dot_product:
        max_dot_product = dot_product
        best_cutoff = cutoff

# Transform pred using the best cutoff
transformed_pred = (pred >= best_cutoff).type(torch.int)