
Segmenting the Spine with UNet

As we collect more data every day, artificial intelligence (AI) will be applied to the medical field more and more. One key application of AI in medicine is diagnosis, where it supports decision-making, case management, automation, and more.

The spine is a vital part of the musculoskeletal system: it supports the body and its organ structures, plays a major role in our mobility and load transfer, and protects the spinal cord from injuries and mechanical shocks caused by impacts.

In an automated spine-processing pipeline, spine labeling and segmentation are two fundamental tasks.

Reliable and accurate spine image processing is expected to aid clinical decision-support systems for diagnosis, surgery planning, and population-based analysis of spine and bone health. Designing automated algorithms for spine processing is challenging, mainly because of the considerable variation in anatomy and acquisition protocols and the severe shortage of publicly available data.

In this blog I focus only on segmenting the spine in a given CT-scan dataset. Labeling each vertebra and performing further diagnosis are not covered here and can be taken up as a continuation of this task.

Spine (vertebral) segmentation is a crucial step in all applications that automatically quantify spinal morphology and pathology.

With the advent of deep learning, large and varied data has become a primary resource for tasks like computed tomography (CT) analysis; until recently, however, no large-scale public dataset was available.

VerSe is a large, multi-detector, multi-site CT spine dataset consisting of 374 scans from 355 patients. Datasets were released for both 2019 and 2020. In this blog I merge the two into a single dataset in order to benefit from more data.

The data is provided under the CC BY-SA 4.0 license, so it is fully open source.

NIfTI (Neuroimaging Informatics Technology Initiative) is a file format for neuroimaging. NIfTI files are very commonly used in the imaging informatics of neuroscience and even neuroradiology research. Each NIfTI file can hold up to seven dimensions of data plus metadata and supports a variety of data types.

The first three dimensions define the three spatial axes x, y, and z, while the fourth dimension defines time points t. The remaining dimensions (fifth through seventh) are reserved for other uses; the fifth dimension can still have predefined purposes, such as storing voxel-specific distribution parameters or holding vector-based data.

ITK-SNAP is a software application for segmenting structures in 3D medical images. It is open source and can be installed on different platforms. I used it to visualize NIfTI files in a 3D view and to load and overlay 3D masks on top of the raw images. I highly recommend it for this task.

Computed tomography (CT) is an X-ray imaging procedure in which X-rays are aimed at the patient while rotating rapidly around the body. The signals collected by the machine are stored in a computer to generate cross-sectional images of the body, also known as "slices".

These slices are called tomographic images and contain more detailed information than conventional X-rays. A series of slices can be digitally "stacked" together to form a 3D image of the patient, making it easier to identify and locate basic structures as well as possible tumors or abnormalities.

The steps are as follows. First, download the 2019 and 2020 datasets.

Then merge the two datasets into combined training, validation, and test folders (a minimal sketch of this merge step is shown below). The next step is to read the CT scans and convert every slice of each scan into a series of PNG raw images and masks. Finally, a UNet model based on the implementation in this GitHub repository is trained.
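The merge itself is not part of the code listing below. Here is a minimal sketch, assuming the two releases were unpacked into sibling folders named ./verse19 and ./verse20 (these folder names are only placeholders), copying each subject folder into one merged training directory:

import glob
import os
import shutil

# Assumed source folders for the unpacked VerSe'19 and VerSe'20 training splits.
source_roots = ["./verse19/training", "./verse20/training"]
merged_root = "./verse_19_20_training"  # matches the folder name used later

for split in ["rawdata", "derivatives"]:
    os.makedirs(os.path.join(merged_root, split), exist_ok=True)
    for root in source_roots:
        for subject_dir in glob.glob(os.path.join(root, split, "sub-*")):
            target = os.path.join(merged_root, split, os.path.basename(subject_dir))
            if not os.path.exists(target):
                shutil.copytree(subject_dir, target)

The same loop can be repeated for the validation split.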

Data understanding: before starting data processing and training, I wanted to load a few NIfTI files to become more familiar with their 3D data structure, to be able to visualize them, and to extract metadata from the images.

After downloading the VerSe datasets, I opened a *.nii.gz* file. By reading one file and looking at a particular slice of the CT scan, I was able to use NumPy's transpose to view that slice in three different orientations: axial, sagittal, and coronal.

After becoming more familiar with the raw images and being able to extract a single slice from the raw 3D volume, it was time to look at the mask file for the same slice.

As you can see in the image below, the mask slice can be overlaid on top of the raw image slice. The reason we see a color gradient here is that the mask files not only contain a defined region for each vertebra but also assign each vertebra a different label (shown in a different color), i.e. its number. For a better understanding of spine labeling, you can refer to this page.

Data preparation:

The data-preparation task is to generate image slices from each 3D CT-scan file, for both the raw images and the mask files.

It starts by reading the raw and mask images (the ".nii.gz" files) with the NiBabel library and converting them to NumPy arrays. It then goes through each 3D image, checks its viewing orientation, and tries to convert most of the images to the sagittal view.

Next, I generate a PNG file from each slice and store it in "L" mode, i.e. grayscale. We do not need RGB images in this case.

For this task the UNet architecture is used to apply semantic segmentation to the dataset. For a better understanding of UNet and semantic segmentation, I recommend this blog.

PyTorch and torchvision are used for this task. As mentioned, this repository is a good PyTorch implementation of UNet, and some of its code is reused here.

Since we are working with NIfTI files, the NiBabel library is used to read them in Python. NiBabel is a Python library for reading and writing some common medical and neuroimaging file formats, such as NIfTI.

Dice score: to evaluate how the model performs on the semantic-segmentation task, we can use the Dice score. The Dice coefficient is 2 times the overlap area (between the predicted mask and the ground-truth mask) divided by the total number of pixels in both images.
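As a quick illustration of the formula (not part of the training code, and with made-up arrays), the Dice coefficient of two binary masks can be computed with NumPy:

import numpy as np

# Two small made-up binary masks (1 = spine pixel, 0 = background).
pred = np.array([[0, 1, 1, 0],
                 [0, 1, 1, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 0]])
truth = np.array([[0, 1, 1, 0],
                  [0, 1, 0, 0],
                  [0, 1, 1, 0],
                  [0, 0, 0, 0]])

intersection = np.sum(pred * truth)                     # 4 overlapping pixels
dice = 2.0 * intersection / (pred.sum() + truth.sum())  # 2*4 / (5 + 5) = 0.8
print(dice)

A perfect prediction gives a Dice score of 1 and no overlap gives 0, which is why the training code below uses 1 - dice as the loss.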

Training: first I defined the UNet class, then the PyTorch dataset class, which includes reading and preprocessing the images. The preprocessing step loads the PNG files, resizes them all to one size (250x250 in this example), converts them to NumPy arrays, and then to PyTorch tensors.

By instantiating the dataset class (VerSeDataset) we can prepare the data in the batches I defined. To make sure the mapping between raw and mask images is correct, I call next(iter(valid_dataloader)) to fetch one batch and visualize it.

The model is then defined as model = UNet(n_channels=1, n_classes=1). The number of channels is 1 because the images are grayscale rather than RGB; if your images are RGB, change n_channels to 3. The number of classes is 1 because there is only one class: whether a pixel belongs to the spine or not. If your problem is multi-class segmentation, set the number of classes to however many classes you have.

Then the model is trained. For each batch, the loss is computed and the parameters are updated through backpropagation. The loop then goes over all validation batches, computes the loss only (no parameter updates), and stores the loss values. Finally, the train and validation losses are plotted to track the model's performance.

After saving the model, I could take one of the images, pass it through the trained model, and get back a predicted mask. By plotting the raw image, the true mask, and the predicted mask side by side, the result can be assessed visually.

As the figure above shows, the model performs very well on both the sagittal and axial views: the predicted masks are very similar to the true mask regions.

Full code:

Author: Mazi Boustani

Date: December 24, 2021

Purpose: train a UNet model with PyTorch to segment the spine using the VerSe dataset

import numpy as np 

import pandas as pd

import os

from os import listdir

from os.path import splitext

import glob

import shutil

import random

from pathlib import Path

from PIL import Image

from tqdm import tqdm

import matplotlib.pyplot as plt

%matplotlib inline

try:

   import nibabel as nib

except ImportError:

   raise ImportError('Install NIBABEL')


import torch

import torch.nn as nn

from torch import Tensor

import torch.nn.functional as F

from torch import optim

import torchvision.transforms as T

from torch.utils.data import DataLoader, random_split

from torch.utils.data import Dataset


# set folder paths for train and validation data

data_folder_path = "/Users/mazi/Projects/other/CT/data"

train_data = data_folder_path + "/verse_19_20_training/"

validation_data = data_folder_path + "/verse_19_20_validation/"


Data understanding

# get one image to load


train_data_raw_image = train_data + "/rawdata/sub-verse521/sub-verse521_dir-ax_ct.nii.gz"

one_image = nib.load(train_data_raw_image)

# look at image shape

print(one_image.shape)

# look at image header. To understand header please refer to: https://brainder.org/2012/09/23/the-nifti-file-format/

print(one_image.header)

# look at the raw data

one_image_data = one_image.get_fdata()

print(one_image_data)


# Visualize one image in three different angles

one_image_data_axial = one_image_data

# change the view

one_image_data_sagittal = np.transpose(one_image_data, [2,1,0])

one_image_data_sagittal = np.flip(one_image_data_sagittal, axis=0)

# change the view

one_image_data_coronal = np.transpose(one_image_data, [2,0,1])

one_image_data_coronal = np.flip(one_image_data_coronal, axis=0)

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(one_image_data_axial[:,:,10], cmap ='bone')

ax[0].set_title("Axial view", fontsize=60)

ax[1].imshow(one_image_data_sagittal[:,:,260], cmap ='bone')

ax[1].set_title("Sagittal view", fontsize=60)

ax[2].imshow(one_image_data_coronal[:,:,200], cmap ='bone')

ax[2].set_title("Coronal view", fontsize=60)

plt.show()

# Overlay a mask on top of raw image (one slice of CT-scan)

train_data_mask_image = train_data + "derivatives/sub-verse521/sub-verse521_dir-ax_seg-vert_msk.nii.gz"

train_data_mask_image = nib.load(train_data_mask_image).get_fdata()

plt.figure(figsize=(10,10))

rotated_raw = np.transpose(one_image_data, [2,1,0])

rotated_raw = np.flip(rotated_raw, axis=0)

plt.imshow(rotated_raw[:,:,260], cmap ='bone', interpolation='none')

train_data_mask_image[train_data_mask_image == 0 ] = np.nan

rotated_mask = np.transpose(train_data_mask_image, [2,1,0])

rotated_mask = np.flip(rotated_mask, axis=0)

plt.imshow(rotated_mask[:,:,260], cmap ='cool')

Preprocess the data

# Set paths to store processed train and validation raw images and masks

processed_train = "./processed_train/"

processed_validation = "./processed_validation/"


processed_train_raw_images = processed_train + "raw_images/"

processed_train_masks = processed_train + "masks/"


processed_validation_raw_images = processed_validation + "raw_images/"

processed_validation_masks = processed_validation + "masks/"


# Read all 2019 and 2020 raw files, both train and validation

raw_train_files = glob.glob(os.path.join(train_data, 'rawdata/*/*.nii.gz'))

raw_validation_files = glob.glob(os.path.join(validation_data, 'rawdata/*/*.nii.gz'))

print("Raw images count train: {0}, validation: {1}".format(len(raw_train_files), len(raw_validation_files)))


# Read all 2019 and 2020 derivatives files, both train and validation

masks_train_files = glob.glob(os.path.join(train_data, 'derivatives/*/*.nii.gz'))

masks_validation_files = glob.glob(os.path.join(validation_data, 'derivatives/*/*.nii.gz'))


print("Masks images count train: {0}, validation: {1}".format(len(masks_train_files), len(masks_validation_files)))


def read_file(nii_file):

   '''

   Read .nii.gz file.

   Args:

     nii_file (str): a file path.
   

   Return:

     3D numpy array of CT image data.

   '''
   

   return np.asanyarray(nib.load(nii_file).dataobj)


def save_file(raw_data, label_data, file_name, index, output_raw_file_path, output_label_file_path):

   '''

   Save one slice of the raw image and its mask as PNG files.


   Args:

     raw_data (array): 2D numpy array of raw image data.

     label_data (array): 2D numpy array of label image data.

     file_name (str): file name.

     index (int): slice of CT image.

     output_raw_file_path (str): Path to all raw files.

     output_label_file_path (str): Path to all mask files.

   '''

   # replace all non-zero pixels to 1

   label_data = np.where(label_data > 0, 1, label_data)

   unique_values = np.unique(label_data)
   

   # if the slice contains any pixel with value 1, it is a positive datapoint

   if len(unique_values) > 1:


       raw_file_name = "{0}{1}_{2}.png".format(output_raw_file_path, file_name, index)

       im = Image.fromarray(raw_data)

       im = im.convert("L")

       im.save(raw_file_name)

       label_file_name = "{0}{1}_{2}.png".format(output_label_file_path, file_name, index)

       im = Image.fromarray(label_data)

       im = im.convert("L")

       im.save(label_file_name)


def is_diagonal(matrix):

   '''

   Check if the given matrix is diagonal or not.

   

   Args:

       matrix (np array): numpy array

   '''

   for i in range(0, 3):

       for j in range(0, 3) :

           if ((i != j) and (matrix[i][j] != 0)):

               return False

   return True


def generate_data(raw_file, label_file, file_name, output_raw_file_path, output_label_file_path):

   '''

   Main function to read each raw and label file and generate series of images

   per each slice.

   Args:

     raw_file (str): path to raw file.

     label_file (str): path to label file.

     file_name (str): file name.

     output_raw_file_path (str): Path to all raw files.

     output_label_file_path (str): Path to all mask files.

   '''
   

   # Skip some slices: adjacent slices can be very similar to each other and

   # will generate redundant data

   skip_slice = 3

   continue_it = True

   raw_data = read_file(raw_file)

   label_data = read_file(label_file)

   if "split" in raw_file:

       continue_it = False
   

   affine = nib.load(raw_file).affine


   if is_diagonal(affine[:3, :3]):

       transposed_raw_data = np.transpose(raw_data, [2,1,0])

       transposed_raw_data = np.flip(transposed_raw_data)

       transposed_label_data = np.transpose(label_data, [2,1,0])

       transposed_label_data = np.flip(transposed_label_data)


   else:

       transposed_raw_data = np.rot90(raw_data)

       transposed_raw_data = np.flip(transposed_raw_data)

       transposed_label_data = np.rot90(label_data)

       transposed_label_data = np.flip(transposed_label_data) 

   if continue_it:

       if transposed_raw_data.shape:

           slice_count = transposed_raw_data.shape[-1]

           print("File name: ", file_name, " - Slice count: ", slice_count)

           # skip some slices

           for each_slice in range(1, slice_count, skip_slice):

              save_file(transposed_raw_data[:,:,each_slice],

                        transposed_label_data[:,:,each_slice],

                         file_name,

                         each_slice,

                         output_raw_file_path,

                         output_label_file_path)


# Loop over raw images and masks and generate 'PNG' images.

print("Processing started.")

for each_raw_file in raw_train_files:

   raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]

   for each_mask_file in masks_train_files:

       if raw_file_name in each_mask_file.split("/")[-1]:

           generate_data(each_raw_file,
                         each_mask_file,
                         raw_file_name,
                         processed_train_raw_images,
                         processed_train_masks)

print("Processing train data done.")

# Loop over raw images and masks and generate 'PNG' images.

for each_raw_file in raw_validation_files:

   raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]

   for each_mask_file in masks_validation_files:

       if raw_file_name in each_mask_file.split("/")[-1]:

           generate_data(each_raw_file,
                         each_mask_file,
                         raw_file_name,
                         processed_validation_raw_images,
                         processed_validation_masks)

print("Processing validation data done.")


Training

# Define model parameters

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# image size to convert to

IMAGE_HEIGHT = 250

IMAGE_WIDTH = 250

LEARNING_RATE = 1e-4

BATCH_SIZE = 10

EPOCHS = 10

NUM_WORKERS = 8


# Set the device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# UNet model parts

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py

class DoubleConv(nn.Module):

   """(convolution => [BN] => ReLU) * 2"""

   def __init__(self, in_channels, out_channels, mid_channels=None):

       super().__init__()

       if not mid_channels:

           mid_channels = out_channels

       self.double_conv = nn.Sequential(

           nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),

           nn.BatchNorm2d(mid_channels),

           nn.ReLU(inplace=True),

           nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),

           nn.BatchNorm2d(out_channels),

           nn.ReLU(inplace=True)

       )

   def forward(self, x):

       return self.double_conv(x)


class Down(nn.Module):

   """Downscaling with maxpool then double conv"""

   def __init__(self, in_channels, out_channels):

       super().__init__()

       self.maxpool_conv = nn.Sequential(

           nn.MaxPool2d(2),

           DoubleConv(in_channels, out_channels)

       )

   def forward(self, x):

       return self.maxpool_conv(x)


class Up(nn.Module):

   """Upscaling then double conv"""

   def __init__(self, in_channels, out_channels, bilinear=True):

       super().__init__()

      

        # if bilinear, use the normal convolutions to reduce the number of channels

       if bilinear:
           self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
           self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
       else:
           self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
           self.conv = DoubleConv(in_channels, out_channels)


   def forward(self, x1, x2):

       x1 = self.up(x1)

       # input is CHW

       diffY = x2.size()[2] - x1.size()[2]

       diffX = x2.size()[3] - x1.size()[3]


       x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,

                       diffY // 2, diffY - diffY // 2])
       

       # if you have padding issues, see

       # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a

       # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

       x = torch.cat([x2, x1], dim=1)

       return self.conv(x)

class OutConv(nn.Module):

   def __init__(self, in_channels, out_channels):

       super(OutConv, self).__init__()

       self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)


   def forward(self, x):

       return self.conv(x)


# Defining UNet architecture

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

class UNet(nn.Module):

   def __init__(self, n_channels, n_classes, bilinear=True):

       super(UNet, self).__init__()

       self.n_channels = n_channels

       self.n_classes = n_classes

       self.bilinear = bilinear

       self.inc = DoubleConv(n_channels, 64)

       self.down1 = Down(64, 128)

       self.down2 = Down(128, 256)

       self.down3 = Down(256, 512)

       factor = 2 if bilinear else 1

       self.down4 = Down(512, 1024 // factor)

       self.up1 = Up(1024, 512 // factor, bilinear)

       self.up2 = Up(512, 256 // factor, bilinear)

       self.up3 = Up(256, 128 // factor, bilinear)

       self.up4 = Up(128, 64, bilinear)

       self.outc = OutConv(64, n_classes)

   def forward(self, x):

       x1 = self.inc(x)

       x2 = self.down1(x1)

       x3 = self.down2(x2)

       x4 = self.down3(x3)

       x5 = self.down4(x4)

       x = self.up1(x5, x4)

       x = self.up2(x, x3)

       x = self.up3(x, x2)

       x = self.up4(x, x1)

       logits = self.outc(x)

       return logits


# Define PyTorch dataset class

# This class will access the images and masks, preprocess them for training and validation


class VerSeDataset(Dataset):

   def __init__(self, raw_images_path, masks_path, images_name):

       self.raw_images_path = raw_images_path

       self.masks_path = masks_path

       self.images_name = images_name
       

   def __len__(self):

       return len(self.images_name)

   def __getitem__(self, index):
       

       # get image and mask for a given index

       img_path = os.path.join(self.raw_images_path, self.images_name[index])

       mask_path = os.path.join(self.masks_path, self.images_name[index])
       

       # read the image and mask

       image = Image.open(img_path)

       mask = Image.open(mask_path)
       

       # resize image and change the shape to (1, image_width, image_height)

       w, h = image.size

       image = image.resize((w, h), resample=Image.BICUBIC)

       image = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(image)

       image_ndarray = np.asarray(image)

       image_ndarray = image_ndarray.reshape(1, image_ndarray.shape[0], image_ndarray.shape[1])

       # resize the mask. Mask shape is (image_width, image_height)

       mask = mask.resize((w, h), resample=Image.NEAREST)

       mask = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(mask)

       mask_ndarray = np.asarray(mask)
       

       return {

            'image': torch.as_tensor(image_ndarray.copy()).float().contiguous(),

            'mask': torch.as_tensor(mask_ndarray.copy()).float().contiguous()

        }


# Get path for all images and masks

train_images_paths = os.listdir(processed_train_raw_images)

train_masks_paths = os.listdir(processed_train_masks)

validation_images_paths = os.listdir(processed_validation_raw_images)

validation_masks_paths = os.listdir(processed_validation_masks)


# Load both images and masks data

train_data = VerSeDataset(processed_train_raw_images, processed_train_masks, train_images_paths)

valid_data = VerSeDataset(processed_validation_raw_images, processed_validation_masks, validation_images_paths)

# Create PyTorch DataLoader

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)


# Looking at one image and mask from one batch just to check them visually

next_image = next(iter(valid_dataloader))

fig, ax = plt.subplots(1, 2, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("Mask image", fontsize=60)

plt.show()

# Defining Dice loss class

# Source code: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch

class DiceLoss(nn.Module):

   def __init__(self, weight=None, size_average=True):

       super(DiceLoss, self).__init__()

   def forward(self, inputs, targets, smooth=1):

       # model outputs are raw logits, so apply sigmoid first
       inputs = torch.sigmoid(inputs)

       # flatten label and prediction tensors
       inputs = inputs.view(-1)
       targets = targets.view(-1)

       intersection = (inputs * targets).sum()
       dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

       # subtract the Dice coefficient from 1 to turn it into a loss
       return 1 - dice


# Define model as UNet


model = UNet(n_channels=1, n_classes=1)

model.to(device=device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# Train and validate


train_loss = []

val_loss = []

for epoch in range(EPOCHS):  
   

   model.train()

   train_running_loss = 0.0

   counter = 0
   

   with tqdm(total=len(train_data), desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='img') as pbar:

       for batch in train_dataloader:
           

           counter+=1

           image = batch['image'].to(DEVICE)

           mask = batch['mask'].to(DEVICE)
           

           optimizer.zero_grad()

           outputs = model(image)

           outputs = outputs.squeeze(1)

           loss = DiceLoss()(outputs, mask)

           train_running_loss += loss.item()

           loss.backward()

           optimizer.step()
           

           pbar.update(image.shape[0])

           pbar.set_postfix(**{'loss (batch)': loss.item()})

        train_loss.append(train_running_loss/counter)
   

   model.eval()

   valid_running_loss = 0.0

   counter = 0

   with torch.no_grad():

       for i, data in enumerate(valid_dataloader):

           counter += 1
           

           image = data['image'].to(DEVICE)

           mask = data['mask'].to(DEVICE)

           outputs = model(image)

           outputs = outputs.squeeze(1)
           

           loss = DiceLoss()(outputs, mask)

           valid_running_loss += loss.item()
           

        val_loss.append(valid_running_loss/counter)


Epoch 1/10: 100%|██████████| 4790/4790 [4:00:34<00:00,  3.01s/img, loss (batch)=0.385]  

Epoch 2/10: 100%|██████████| 4790/4790 [4:00:02<00:00,  3.01s/img, loss (batch)=0.268]  

Epoch 3/10: 100%|██████████| 4790/4790 [3:57:30<00:00,  2.98s/img, loss (batch)=0.152]  

Epoch 4/10: 100%|██████████| 4790/4790 [3:57:05<00:00,  2.97s/img, loss (batch)=0.105]  

Epoch 5/10: 100%|██████████| 4790/4790 [4:08:29<00:00,  3.11s/img, loss (batch)=0.103]   

Epoch 6/10: 100%|██████████| 4790/4790 [4:04:12<00:00,  3.06s/img, loss (batch)=0.0874]  

Epoch 7/10: 100%|██████████| 4790/4790 [4:02:00<00:00,  3.03s/img, loss (batch)=0.0759]  

Epoch 8/10: 100%|██████████| 4790/4790 [3:58:32<00:00,  2.99s/img, loss (batch)=0.0655]  

Epoch 9/10: 100%|██████████| 4790/4790 [4:00:47<00:00,  3.02s/img, loss (batch)=0.0644]  

Epoch 10/10: 100%|██████████| 4790/4790 [4:08:54<00:00,  3.12s/img, loss (batch)=0.0604]  


# Plot train vs validation loss

plt.figure(figsize=(10, 7))

plt.plot(train_loss, color="orange", label='train loss')

plt.plot(val_loss, color="red", label='validation loss')

plt.xlabel("Epochs")

plt.ylabel("Loss")

plt.legend()

plt.show()

# Save the trained model

torch.save({

   'epoch': EPOCHS,

   'model_state_dict': model.state_dict(),

   'optimizer_state_dict': optimizer.state_dict(),
}, "./unet_model.pth")


# Visually look at one prediction 

next_image = next(iter(valid_dataloader))

# do predict

outputs = model(next_image['image'].float().to(DEVICE))

outputs = outputs.detach().cpu()

loss = DiceLoss()(outputs, next_image['mask'])

print("Dice Score: ", 1- loss.item())

outputs[outputs<=0.0] = 0

outputs[outputs>0.0] = 1.0

# plot all three images

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw Image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("True Mask", fontsize=60)

ax[2].imshow(outputs[0,0,:,:], cmap ='bone')

ax[2].set_title("Predicted Mask", fontsize=60)

plt.show()

Future work: this task could also be tackled with a 3D UNet, which might be a better way to learn the spine's structure.

Since every vertebra's mask region carries its own label, we could go further and perform multi-class mask segmentation (a rough sketch of the required changes is given below). Also, the model performed best when the image was in the sagittal view, so converting all slices to the sagittal view might yield better results.
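A rough sketch of what the multi-class variant might involve, assuming the vertebra labels are kept in the masks instead of being binarized in save_file (the number of labels below is only a placeholder):

# Hypothetical multi-class setup: drop the np.where(label_data > 0, 1, label_data)
# binarization in save_file() so the masks keep their vertebra labels, then train
# with one output channel per label plus background.
N_VERTEBRA_LABELS = 24  # placeholder; use the actual number of labels in your masks

model = UNet(n_channels=1, n_classes=N_VERTEBRA_LABELS + 1)  # +1 for background
criterion = nn.CrossEntropyLoss()

# The masks would then be integer label maps of shape (batch, H, W):
# loss = criterion(model(image), mask.long())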

Thanks for reading!
