Training a U-Net for brain segmentation


I have two questions:

Q1: I was wondering what the best way is to feed the U-Net with training data:

  1. Sending in one patient at a time, where each volume is 160x3x192x192
  2. Sending in random slices from k patients (a rough sketch of what I mean is below)
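
For option 2, I picture a dataset that flattens the slices of several patients, so that a shuffled DataLoader draws random slices across patients. A minimal sketch, assuming the volumes and masks are already loaded as tensors (RandomSliceDataset and its arguments are made-up names, not my actual MSdataset):

import torch
from torch.utils.data import Dataset, DataLoader

class RandomSliceDataset(Dataset):
    """Flattens (volume, mask) pairs from several patients into single 2D slices."""
    def __init__(self, volumes, masks):
        # volumes: list of (160, 3, 192, 192) tensors; masks: list of (160, 192, 192)
        self.slices = [(v[d], m[d]) for v, m in zip(volumes, masks)
                       for d in range(v.shape[0])]

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        image, mask = self.slices[idx]
        return {"volume": image, "mask": mask}

With DataLoader(slice_dataset, batch_size=14, shuffle=True), each batch would then mix slices from different patients instead of walking through a single volume in order.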

Q2: At the moment I have implemented the first option, but have not received any good results. I am getting an oscillating dice score: for example, the dice loss starts at 0.99, goes down to 0.8, then spikes to 8, and the pattern repeats. Does anyone know why this happens?
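
(By dice loss I mean the usual soft-Dice formulation, 1 minus the Dice coefficient. A minimal sketch of that formulation; my actual DiceLoss may differ in the smoothing term and reduction:)

import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """Soft Dice loss: 1 - Dice coefficient, smoothed to avoid division by zero."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, predicted, target):
        # predicted: (N, H, W) probabilities in [0, 1]; target: (N, H, W) binary mask
        predicted = predicted.contiguous().view(-1)
        target = target.contiguous().view(-1).float()
        intersection = (predicted * target).sum()
        dice = (2.0 * intersection + self.smooth) / (predicted.sum() + target.sum() + self.smooth)
        return 1.0 - dice

A single soft-Dice loss like this lies in [0, 1], so a value of 8 can only appear after summing several sub-batch losses, as loss_depths does in the code below.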

code:

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
# project-specific helpers (definitions not shown): MSdataset, Get_mean_std,
# normalize, add_channel, ToTensor, UNet, DiceLoss

class main:
    def __init__(self, args):
        self.args = args
        self.train_loader = None
        self.in_channel = None
        self.out_channel = None

    def _config_dataloader(self):
        print("Starting configuration of the dataset")
        print("Collecting validation and training set")

        validation_mode = "val/"
        training_mode = "train/"

        # compute per-modality mean/std over k training patients
        collect = Get_mean_std(self.args.path + training_mode)
        mean, std = collect(self.args.k)

        mean_flair = mean["FLAIR"]
        mean_t1 = mean["T1"]

        std_flair = std["FLAIR"]
        std_t1 = std["T1"]

        train_dataset = MSdataset(self.args.path + training_mode, composed_transforms=[
            normalize(z_norm=True, mean=mean_flair, std=std_flair),
            normalize(z_norm=True, mean=mean_t1, std=std_t1),
            add_channel(depth=self.args.depth),
            ToTensor()])

        validation_dataset = MSdataset(self.args.path + validation_mode, composed_transforms=[
            normalize(z_norm=True, mean=mean_flair, std=std_flair),
            normalize(z_norm=True, mean=mean_t1, std=std_t1),
            add_channel(depth=self.args.depth),
            ToTensor()])

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.args.batch_size,
                                  shuffle=self.args.shuffle)

        validation_loader = DataLoader(validation_dataset,
                                       batch_size=self.args.batch_size - 1,
                                       shuffle=self.args.shuffle)

        print("Data collected. Returning dataloaders for training and validation set")
        return train_loader, validation_loader

    def __call__(self, is_train=False):
        train_loader, validation_loader = self._config_dataloader()

        complete_data = {"train": train_loader, "validation": validation_loader}

        device = torch.device("cpu" if not torch.cuda.is_available() else self.args.device)

        unet = UNet(in_channels=3, out_channels=1, init_features=32)
        unet.to(device)

        optimizer = optim.Adam(unet.parameters(), lr=self.args.lr)
        dsc_loss = DiceLoss()

        loss_train = []
        loss_valid = []  # not filled in yet, the validation branch is unfinished

        print("Starting training process. Please wait..")
        sub_batch_size = 14
        for current_epoch in tqdm(range(self.args.epoch), total=self.args.epoch):

            for phase in ["train", "validation"]:

                if phase == "train":
                    unet.train()

                if phase == "validation":
                    unet.eval()

                for i, data_set_batch in enumerate(complete_data[phase]):
                    data_dict = data_set_batch
                    X, mask = data_dict["volume"], data_dict["mask"]
                    X, mask = (X.to(device)).float(), mask.to(device)
                    B, D, C, H, W = X.shape

                    # fold the depth dimension into the batch dimension:
                    # each 2D slice becomes one sample of shape (C, H, W)
                    mask = mask.reshape((B * D, H, W))
                    X = X.reshape((B * D, C, H, W))

                    loss_depths = 0  # reset the accumulated depth loss
                    with torch.set_grad_enabled(phase == "train"):  # gradients only while training

                        # note: the range steps by 1, so consecutive sub-batches
                        # overlap by sub_batch_size - 1 slices and the summed
                        # loss_depths can grow far above 1
                        for sub_batches in tqdm(range(0, X.shape[0] - sub_batch_size)):

                            predicted = unet(X[sub_batches: sub_batches + sub_batch_size, :, :, :])
                            loss = dsc_loss(predicted.squeeze(1),
                                            mask[sub_batches: sub_batches + sub_batch_size, :, :])

                            if phase == "train":
                                loss_depths = loss_depths + loss
                            if phase == "validation":
                                continue  # validation loss not accumulated yet (see note below)

                    if phase == "train":
                        # one optimizer step per patient volume, on the summed loss
                        loss_train.append(loss_depths.item())  # store a plain float, not the graph
                        loss_depths.backward()
                        optimizer.step()
                        optimizer.zero_grad()

        print("Training and validation is done. Exiting program and returning loss")
        return loss_train

Note that I have not fully implemented the validation section; I just wanted to see how the network learns first. Thanks!
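
For completeness, here is roughly how I imagine the validation pass will eventually look (a sketch under the same data assumptions as the training loop; validate is a hypothetical helper, and the sub-batches are non-overlapping here):

import torch

def validate(unet, loader, dsc_loss, device, sub_batch_size=14):
    """One validation pass; returns the mean sub-batch dice loss."""
    unet.eval()
    total, count = 0.0, 0
    with torch.no_grad():  # no gradients needed for validation
        for batch in loader:
            X = batch["volume"].to(device).float()
            mask = batch["mask"].to(device)
            B, D, C, H, W = X.shape
            X = X.reshape((B * D, C, H, W))
            mask = mask.reshape((B * D, H, W))
            for start in range(0, X.shape[0], sub_batch_size):  # non-overlapping windows
                predicted = unet(X[start: start + sub_batch_size])
                total += dsc_loss(predicted.squeeze(1),
                                  mask[start: start + sub_batch_size]).item()
                count += 1
    return total / max(count, 1)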
