Converting a PNG image to a np.array while running the train method in a neural network

63 Views Asked by At

I'm trying to create a custom transform to adjust the rotation of an X-Ray for training a neural network, as seen here - Link to rotating X-rays. The function works on test images, whether it's a PNG or JPG, however, when I add it to the neural network processing, the conversion from PNG to np.array doesn't work correctly. It ends up generating a completely black image, which then (obviously) can't be used for edge detection. When stepping through the code, it does convert the PIL.Image to a np.array, but something goes wrong. Here's the code:

Adjusting Image method:

class AdjustImage(object):
    THRESHOLD = 240
    def subimage(self, image, center, theta, width, height):
        if 45 < theta <= 90:
            theta = theta - 90
            width, height = height, width
    
        theta *= math.pi / 180 # convert to rad
        v_x = (math.cos(theta), math.sin(theta))
        v_y = (-math.sin(theta), math.cos(theta))
        s_x = center[0] - v_x[0] * (width / 2) - v_y[0] * (height / 2)
        s_y = center[1] - v_x[1] * (width / 2) - v_y[1] * (height / 2)
        mapping = np.array([[v_x[0],v_y[0], s_x], [v_x[1],v_y[1], s_y]])
        return cv2.warpAffine(image, mapping, (width, height), flags=cv2.WARP_INVERSE_MAP, borderMode=cv2.BORDER_REPLICATE)
    def __call__(self, image_source):
         # First slightly crop edge - some images had a rogue 2 pixel black edge on one side
        #https://github.com/python-pillow/Pillow/issues/6765 - pil issue with pngs
        type_var = image_source.format
        image_source.show(title="Initial Image")
        if isinstance(image_source, np.ndarray):
            print("working")
        elif image_source.format == 'JPEG':
            image_source = np.array(image_source)
            image_source = image_source[:,:,::-1].copy()
        else:
            image_source = image_source.point(lambda x: x / 256)
            image_source = image_source.convert('RGB')
            image_source.show(title="After Conversion")
            image_source = np.array(image_source)
            image_source = image_source[:,:,::-1].copy()
            
            #This method was designed and tested on using cv2.imread - these steps change a PIL PNG read into that.
            #Otherwise, it will destroy the image
           
        init_crop = 5
        
        h, w = image_source.shape[:2]
        image_source = image_source[init_crop:init_crop+(h-init_crop*2), init_crop:init_crop+(w-init_crop*2)]
        
        # Add back a white border
        image_source = cv2.copyMakeBorder(image_source, 5,5,5,5, cv2.BORDER_CONSTANT, value=(255,255,255))
        
        image_gray = cv2.cvtColor(image_source, cv2.COLOR_BGR2GRAY)
        _, image_thresh = cv2.threshold(image_gray, self.THRESHOLD, 255, cv2.THRESH_TOZERO_INV)
        
        image_thresh2 = image_thresh.copy()
        image_thresh2 = cv2.Canny(image_thresh2, 100, 100, apertureSize=3)
        
        points = cv2.findNonZero(image_thresh2)
    
        centre, dimensions, theta = cv2.minAreaRect(points)
        rect = cv2.minAreaRect(points)
    
        width = int(dimensions[0])
        height = int(dimensions[1])
    
        box = cv2.boxPoints(rect)
        box = np.int0(box)
    
        temp = image_source.copy()
        cv2.drawContours(temp, [box], 0, (255,0,0), 2)
    
        M = cv2.moments(box)    
        cx = int(M['m10']/M['m00'])
        cy = int(M['m01']/M['m00'])
    
        image_patch = self.subimage(image_source, (cx, cy), (theta+90), height, width)
    
        # add back a small border
        image_patch = cv2.copyMakeBorder(image_patch, 1,1,1,1, cv2.BORDER_CONSTANT, value=(255,255,255))
    
        # Convert image to binary, edge is black. Do edge detection and convert edges to a list of points.
        # Then calculate a minimum set of points that can enclose the points.
        _, image_thresh = cv2.threshold(image_patch, self.THRESHOLD, 255, 1)
        image_thresh = cv2.Canny(image_thresh, 100, 100, 3)
        points = cv2.findNonZero(image_thresh)
        hull = cv2.convexHull(points)
    
        # Find min epsilon resulting in exactly 4 points, typically between 7 and 21
        # This is the smallest set of 4 points to enclose the image.
        for epsilon in range(3, 50):
            hull_simple = cv2.approxPolyDP(hull, epsilon, 1)
    
            if len(hull_simple) == 4:
                break
    
        hull = hull_simple
    
        # Find closest fitting image size and warp/crop to fit
        # (ie reduce scaling to a minimum)
    
        x,y,w,h = cv2.boundingRect(hull)
        target_corners = np.array([[0,0],[w,0],[w,h],[0,h]], np.float32)
    
        # Sort hull into tl,tr,br,bl order. 
        # n.b. hull is already sorted in clockwise order, we just need to know where top left is.
    
        source_corners = hull.reshape(-1,2).astype('float32')
        min_dist = 100000
        index = 0
    
        for n in range(len(source_corners)):
            x,y = source_corners[n]
            dist = math.hypot(x,y)
    
            if dist < min_dist:
                index = n
                min_dist = dist
    
        # Rotate the array so tl is first
        source_corners = np.roll(source_corners , -(2*index))
    
        try:
            transform = cv2.getPerspectiveTransform(source_corners, target_corners)
            return cv2.warpPerspective(image_patch, transform, (w,h))
    
        except:
            print ("Warp failure", image_source)
            return image_patch

    def __repr__(self):
        return self.__class__.__name__+'()'

Setting up the neural network:

#Constructing the ResNeXt model:
transforms = v2.Compose([AdjustImage(),v2.Resize([256,256]), v2.PILToTensor()])
train_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\train", transform = transforms)
test_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\test", transform = transforms)
val_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\val", transform = transforms)

batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size)
test_dataloader = DataLoader(test_ds, batch_size=batch_size)
val_dataloader = DataLoader(val_ds, batch_size = batch_size)
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnext50_32x4d', pretrained=True)
print(model)
model.type(torch.LongTensor)
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr = 1e-3)

train method:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.type(torch.FloatTensor), y.type(torch.FloatTensor)
        X,y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y.type(torch.LongTensor).to(device))

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

Running the neural net:

epochs = 5
model.to(device)
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimiser)
    test(test_dataloader, model, loss_fn)
print("Done!")

It errors out when it tries to find the hull of nothing. I'm using pytorch's resnext-50, and haven't tweaked anything there.

Edit: After trying the step outlined here, it again works in isolation, but not in the neural network. Trying the

image_source = np.asarray(image_source).astype(float)/255.0

step and writing out to a file using cv2 just produces a black image.

Edit 2: The solution is to use the conversions:

na = (np.array(image)>>8).astype(np.uint8)
        na = Image.fromarray(na)

as the loader for the ImageFolder like so:

train_ds = datasets.ImageFolder(<path to folder>, transform = transforms, loader=lambda path: Image.fromarray((np.array(Image.open(path))>>8).astype(np.uint8)))

This loads the images correctly.

1

There are 1 best solutions below

2
Mark Setchell On

I think the following line is doing you in:

image_source = image_source.point(lambda x: x / 256)

I am assuming image_source is a PIL Image, that was loaded from a PNG. If so, there is at least one problem. If it has pixels in the range 0..255 and you divide by 256 and store the answer in an 8-bit unsigned char, you are going to get a result of either 0 or 1 (with nothing in-between) for every pixel. I guess you actually want a float in the range 0..1, so you probably want to convert to a Numpy array, then to float, then divide along these lines:

myFloatArray = np.asarray(image_source).astype(float)/256.0

And actually, I think you really want to divide by 255 rather than by 256, by the way since the maximum unsigned int will be 255 and you presumably want that to map to 1, rather than 0.99609375 (255/256).


This little example demonstrates what I think is happening:

#!/usr/bin/env python3

from PIL import Image
import numpy as np

# Make a 256x1 greyscale gradient image
gradient = Image.fromarray(np.arange(256, dtype=np.uint8))

# Print pixels as a list
print(",".join(map(str,gradient.getdata())))

# Apply your point() function
gradient = gradient.point(lambda x: x / 256)

# Print pixels as a list
print(",".join(map(str,gradient.getdata())))

Output

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1

Another potential issue is that a PNG image could be 16-bit and that will cause even more headaches because firstly the range will be 0..65535 so your scaling would be wrong, and secondly because PIL won't read 16-bit PNGs if they are colour, so be aware of that.