diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py index 81a890aa68..3faef84917 100644 --- a/monai/metrics/froc.py +++ b/monai/metrics/froc.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any import numpy as np import torch @@ -67,12 +67,14 @@ def compute_fp_tp_probs_nd( hittedlabel = evaluation_mask[tuple(coords.T)] fp_probs = probs[np.where(hittedlabel == 0)] + num_targets = 0 for i in range(1, max_label + 1): - if i not in labels_to_exclude and i in hittedlabel: - tp_probs[i - 1] = probs[np.where(hittedlabel == i)].max() + if i not in labels_to_exclude: + num_targets += 1 + if i in hittedlabel: + tp_probs[i - 1] = probs[np.where(hittedlabel == i)].max() - num_targets = max_label - len(labels_to_exclude) - return fp_probs, tp_probs, cast(int, num_targets) + return fp_probs, tp_probs, num_targets def compute_fp_tp_probs( diff --git a/tests/metrics/test_compute_froc.py b/tests/metrics/test_compute_froc.py index 4dc0507366..aa889ddb07 100644 --- a/tests/metrics/test_compute_froc.py +++ b/tests/metrics/test_compute_froc.py @@ -60,6 +60,34 @@ 3, ] +TEST_CASE_EXCLUDE_ABSENT = [ + { + "probs": torch.tensor([1, 0.6, 0.8]), + "y_coord": torch.tensor([0, 2, 3]), + "x_coord": torch.tensor([3, 0, 1]), + "evaluation_mask": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]), + "labels_to_exclude": [5], + "resolution_level": 0, + }, + np.array([0.6]), + np.array([1, 0, 0.8]), + 3, +] + +TEST_CASE_EXCLUDE_DUPLICATE = [ + { + "probs": torch.tensor([1, 0.6, 0.8]), + "y_coord": torch.tensor([0, 2, 3]), + "x_coord": torch.tensor([3, 0, 1]), + "evaluation_mask": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]), + "labels_to_exclude": [2, 2], + "resolution_level": 0, + }, + np.array([0.6]), + np.array([1, 0, 0.8]), + 2, +] + TEST_CASE_4 = [ { "fp_probs": np.array([0.8, 0.6]), @@ -112,7 +140,9 @@ class TestComputeFpTp(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + @parameterized.expand( + [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_EXCLUDE_ABSENT, TEST_CASE_EXCLUDE_DUPLICATE] + ) def test_value(self, input_data, expected_fp, expected_tp, expected_num): fp_probs, tp_probs, num_tumors = compute_fp_tp_probs(**input_data) np.testing.assert_allclose(fp_probs, expected_fp, rtol=1e-5)