# Orthanc plugin for mammography
# Copyright (C) 2024 Edouard Chatzopoulos and Sebastien Jodogne,
# ICTEAM UCLouvain, Belgium
#
# This program is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License
# as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.v2 as transforms

from torchvision.models.detection import RetinaNet
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor
from torchvision.models.detection.retinanet import RetinaNetHead
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.ops import FrozenBatchNorm2d
from torchvision.ops.feature_pyramid_network import LastLevelP6P7

import download

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
MODELS_DIR = os.path.join(SCRIPT_DIR, 'models')

os.makedirs(MODELS_DIR, exist_ok=True)

# Download the pretrained weights into MODELS_DIR; each call provides
# the target path, the source URL, the expected file size (in bytes),
# and the expected MD5 checksum of the file
download.get(os.path.join(MODELS_DIR, 'resnet50-11ad3fa6.pth'),
             'https://orthanc.uclouvain.be/downloads/cross-platform/orthanc-mammography/models/2024-03-08-resnet50-11ad3fa6.pth',
             102540417, '012571d812f34f8442473d8b827077b5')

download.get(os.path.join(MODELS_DIR, 'retina_res50_trained_08_03.pth'),
             'https://orthanc.uclouvain.be/downloads/cross-platform/orthanc-mammography/models/2024-03-08-retina_res50_trained_08_03.pth',
             145735292, '53aa159ea0b83234d767aacb43619748')


class ResizeBetter:
    # Resize a (C, H, W) tensor so that its height equals "min_size",
    # scaling the width by the same factor to preserve the aspect ratio
    def __init__(self, min_size=1750):
        self.min_size = min_size

    def __call__(self, sample):
        shape = sample[0].shape[-2:]
        return transforms.Resize((self.min_size, int(self.min_size * shape[1] / shape[0])),
                                 interpolation=transforms.InterpolationMode.BILINEAR,
                                 antialias=True)(sample)


def anchorgen():
    # Three anchor scales per pyramid level (x, x * 2^(1/3), and
    # x * 2^(2/3)), with the standard 0.5, 1.0, and 2.0 aspect ratios
    anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
    return anchor_generator
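
# For reference, the anchor sizes generated above evaluate to:
#   ((32, 40, 50), (64, 80, 101), (128, 161, 203),
#    (256, 322, 406), (512, 645, 812))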


def load_model(config, pretrained_path, mean=0, std=1):
    if False:
        # Disabled switch: if enabled, the backbone weights would be
        # downloaded from the Internet instead of using the local copy
        model_backbone = resnet50(norm_layer=FrozenBatchNorm2d,
                                  weights=ResNet50_Weights.DEFAULT)
    else:
        # Use the backbone weights that were downloaded above
        weights = torch.load(os.path.join(MODELS_DIR, 'resnet50-11ad3fa6.pth'), map_location=torch.device('cpu'))
        model_backbone = resnet50(norm_layer=FrozenBatchNorm2d)
        model_backbone.load_state_dict(weights)

    # Replace the fully connected classification layer of the backbone
    model_backbone.fc = nn.Sequential(
        nn.Linear(4 * 512, config["num_classes"])
    )

    # Build a feature pyramid network (FPN) on top of the backbone
    model_backbone = _resnet_fpn_extractor(model_backbone,
                                           config["trainable_backbone_layers"],
                                           returned_layers=[2, 3, 4],
                                           extra_blocks=LastLevelP6P7(2048, 256))

    anchor_generator = anchorgen()

    head = RetinaNetHead(
        model_backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        config["num_classes"],
        norm_layer=partial(nn.GroupNorm, 32),
    )
    # Use the generalized IoU (GIoU) loss for box regression
    head.regression_head._loss_type = "giou"

    # By default, "image_mean" and "image_std" make the internal
    # normalization an identity, as the input tensor is already
    # normalized by dicom_to_tensor()
    model = RetinaNet(model_backbone,
                      num_classes=config["num_classes"],
                      anchor_generator=anchor_generator,
                      head=head,
                      min_size=config["min_size"],
                      max_size=config["max_size"],
                      image_mean=[mean, mean, mean],
                      image_std=[std, std, std],
                      fg_iou_thresh=config["fg_iou_thresh"],
                      bg_iou_thresh=config["bg_iou_thresh"],
                      nms_thresh=config["nms_thresh"],
                      _skip_resize=True)

    if pretrained_path is not None:
        state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        print(f"Model loaded from checkpoint {pretrained_path}")

    # Freeze all the batch normalization layers
    for layer in model.modules():
        if isinstance(layer, (nn.BatchNorm2d, nn.BatchNorm1d)):
            layer.eval()  # Set to evaluation mode
            layer.weight.requires_grad = False
            layer.bias.requires_grad = False

    model.eval()
    return model


def load_retina_net():
    config = {
        'num_classes': 2,
        'min_size': 2048,
        'max_size': 2048,
        'trainable_backbone_layers': 0,
        'fg_iou_thresh': 0.5,
        'bg_iou_thresh': 0.4,
        'nms_thresh': 0.3,
    }

    return {
        'min_size': config['min_size'],
        'eval': load_model(config, os.path.join(MODELS_DIR, 'retina_res50_trained_08_03.pth')),
    }


def dicom_to_tensor(dicom, min_size):
    # The model expects a single-frame, grayscale image
    assert len(dicom.pixel_array.shape) == 2

    # Replicate the grayscale image over 3 channels, as expected by the
    # ResNet50 backbone, then convert to a (C, H, W) float tensor
    im_array = np.stack((dicom.pixel_array,) * 3, axis=-1)
    image_tensor = torch.tensor(im_array.astype(np.float32).transpose(2, 0, 1))

    # Resize the height to "min_size" pixels, preserving the aspect ratio
    image_tensor = ResizeBetter(min_size)(image_tensor)

    # Normalize the intensities to zero mean and unit standard deviation
    std = torch.std(image_tensor)
    mean = torch.mean(image_tensor)
    image_tensor = torch.sub(image_tensor, mean)
    image_tensor = torch.div(image_tensor, std)

    assert len(image_tensor.shape) == 3

    return image_tensor
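
# For instance, with min_size=2048 (the value used by load_retina_net),
# a 4096x3328 mammogram is turned into a normalized tensor of shape
# (3, 2048, 1664)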


def apply_model_to_dicom(model, dicom, rescale_boxes=True):
    image_tensor = dicom_to_tensor(dicom, model['min_size'])
    output = model['eval']([image_tensor])

    # One image was provided, so exactly one result is returned
    assert len(output) == 1
    output = output[0]

    if rescale_boxes:
        # Map the detected boxes back to the coordinate system of the
        # original DICOM image
        original_width = dicom.pixel_array.shape[1]
        original_height = dicom.pixel_array.shape[0]
        resized_width = image_tensor.shape[2]
        resized_height = image_tensor.shape[1]

        # TODO - The "int()" in ResizeBetter is anisotropic
        ratio = original_width / resized_width
        output['boxes'] *= ratio

    return output
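

# Minimal usage sketch, for illustration only (this block is not part
# of the Orthanc plugin itself): it assumes that pydicom is available
# and that 'sample.dcm' is a hypothetical path to a mammography DICOM
# file
if __name__ == '__main__':
    import pydicom

    model = load_retina_net()
    dicom = pydicom.dcmread('sample.dcm')  # Hypothetical sample file

    with torch.no_grad():
        detections = apply_model_to_dicom(model, dicom)

    # The "boxes" are expressed in the coordinates of the original image
    for box, score in zip(detections['boxes'], detections['scores']):
        print('Box %s detected with score %.3f' % (box.tolist(), score.item()))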