Official code for Dual-Space Cold-Start Active Learning Guided by SAM3 for Medical Image Segmentation.
This repository follows a simple workflow: prepare data, extract SAM3-guided information, generate active-learning plans, and train the segmentation model.
Data processing scripts are under data_preprocess/.
data_preprocess/data_proprocessed.py splits PROMISE12-style 3D .mhd cases into 70% train, 10% validation, and 20% test at the case level. It converts each volume into 2D .npy slices, saving images as *_image.npy, masks as *_segmentation.npy, and split files as train.csv, valid.csv, and test.csv.
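The case-level split can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in data_proprocessed.py. `split_cases` and `slice_names` are hypothetical helpers. The seed, the rounding rule, and the zero-padded slice index in the file names are also assumptions. Splitting case IDs before slicing keeps every slice of a patient in a single split, so no case leaks between train and test.

```python
import random


def split_cases(case_ids, train=0.7, valid=0.1, seed=0):
    """Hypothetical helper: shuffle case IDs and split them 70/10/20 at the case level."""
    ids = sorted(case_ids)  # fixed order before shuffling so the split is reproducible
    rng = random.Random(seed)  # assumed seed; the repository's choice may differ
    rng.shuffle(ids)
    n = len(ids)
    n_train = int(round(n * train))
    n_valid = int(round(n * valid))
    return {
        "train": ids[:n_train],
        "valid": ids[n_train:n_train + n_valid],
        "test": ids[n_train + n_valid:],  # remaining ~20% of cases
    }


def slice_names(case_id, num_slices):
    """Hypothetical naming pattern following the *_image.npy / *_segmentation.npy convention."""
    return [
        (f"{case_id}_{z:03d}_image.npy", f"{case_id}_{z:03d}_segmentation.npy")
        for z in range(num_slices)
    ]
```

For example, 50 case IDs yield 35 training, 5 validation, and 10 test cases. Every slice of a given case then inherits that case's split.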
data_preprocess/extract_feature.py extracts SAM3 visual features for candidate training images. The feature file keeps the same stem as the source image so it can be matched during query-plan generation.
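Because each feature file shares its stem with its source image, the two can be paired with a dictionary lookup. The sketch below rests on two assumptions. First, feature files may use a different extension; `.pt` in the example is only a placeholder. Second, stems are unique within the feature directory. `match_features` is a hypothetical helper, not a function from this repository.

```python
from pathlib import Path


def match_features(image_paths, feature_paths):
    """Hypothetical helper: pair each image with the feature file that has the same stem."""
    # Path.stem drops only the last suffix: "Case00_010_image.npy" -> "Case00_010_image"
    by_stem = {Path(p).stem: p for p in feature_paths}
    matched, missing = {}, []
    for img in image_paths:
        stem = Path(img).stem
        if stem in by_stem:
            matched[img] = by_stem[stem]
        else:
            missing.append(img)  # images without extracted features
    return matched, missing
```

Reporting unmatched images separately makes a missing or partially finished feature-extraction run easy to spot before plan generation.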
data_preprocess/sam3_pseudo_labeling.py generates SAM3 pseudo masks from image slices using a text prompt such as "prostate".
For DSAS, data_preprocess/mask_quality/ computes mask geometry and SAM3 attention quality. The final file is property/mask_properties.csv, containing fields such as size, center_x, center_y, foreground_ratio, complexity_score, and attn_max_p.
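Some of the geometric fields can be computed from a binary mask as shown below. This is a hedged NumPy sketch, not the code in data_preprocess/mask_quality/. It assumes the centroid is the mean pixel coordinate, with x as the column and y as the row. `complexity_proxy` is a hypothetical perimeter-based stand-in for complexity_score, whose exact definition is not given here. attn_max_p is omitted because it depends on SAM3 attention maps.

```python
import numpy as np


def mask_properties(mask):
    """Sketch of geometry fields for a binary 2D mask (not the repository's exact definitions)."""
    fg = np.asarray(mask) > 0
    h, w = fg.shape
    size = int(fg.sum())
    if size == 0:
        return {"size": 0, "center_x": float("nan"), "center_y": float("nan"),
                "foreground_ratio": 0.0, "boundary_pixels": 0, "complexity_proxy": 0.0}
    ys, xs = np.nonzero(fg)  # row indices (y) and column indices (x) of foreground pixels
    # A pixel is interior if all four 4-neighbours are foreground (image border counts as background).
    p = np.pad(fg, 1)
    interior = p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]
    boundary = int((fg & ~interior).sum())
    return {
        "size": size,
        "center_x": float(xs.mean()),
        "center_y": float(ys.mean()),
        "foreground_ratio": size / (h * w),
        "boundary_pixels": boundary,
        # Hypothetical complexity measure: perimeter^2 / (4*pi*area); larger for irregular shapes.
        "complexity_proxy": boundary ** 2 / (4 * np.pi * size),
    }
```

For a 4x4 square in a 10x10 slice, this gives size 16, a foreground ratio of 0.16, and 12 boundary pixels.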
Query strategies are implemented in sample_selection/query_strategies/.
Feature-only baselines can be generated with:
python3 sample_selection/generate_plans.py
DSAS plans are generated with:
python3 sample_selection/generate_dsas.py \
--budgets 15 30 60 100 \
--feature-dir /path/to/train_slices_sam3_feature \
--base-image-dir /path/to/all_slices \
--score-csv /path/to/property/mask_properties.csv \
--save-dir /path/to/plans
Each generated plan is a CSV with:
image,segmentation
/path/to/image.npy,/path/to/segmentation.npy
The plan can be used directly as the training CSV.
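Writing a plan in this format only requires the selected image paths, because mask paths follow the *_image.npy / *_segmentation.npy naming used during preprocessing. `plan_rows` and `write_plan` are hypothetical helpers that sketch this final step. They assume selection (feature-only or DSAS) has already produced the list of image paths.

```python
import csv
from pathlib import Path

IMAGE_SUFFIX = "_image.npy"
MASK_SUFFIX = "_segmentation.npy"


def plan_rows(selected_images):
    """Hypothetical helper: map each *_image.npy path to its *_segmentation.npy mask path."""
    rows = []
    for img in map(str, selected_images):
        if not img.endswith(IMAGE_SUFFIX):
            raise ValueError(f"unexpected image file name: {img}")
        rows.append((img, img[: -len(IMAGE_SUFFIX)] + MASK_SUFFIX))
    return rows


def write_plan(selected_images, save_dir, name):
    """Hypothetical helper: write <save_dir>/<name>.csv with an image,segmentation header."""
    out = Path(save_dir) / f"{name}.csv"
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "segmentation"])
        writer.writerows(plan_rows(selected_images))
    return out
```

With several budgets (for example 15, 30, 60, and 100), write_plan would be called once per budget, using a budget-specific name for each file.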
Training is handled by train/train.py. The training CSV can be either the full split or a generated active-learning plan.
Train with the full split:
python3 train/train.py \
--train_csv /path/to/data/promise/mydata/splits/train.csv \
--val_csv /path/to/data/promise/mydata/splits/valid.csv \
--pretrain_weights /path/to/checkpoint/backbone.safetensors \
--output_dir /path/to/train/runs \
--exp_name full_train
Train with a DSAS plan:
python3 train/train.py \
--train_csv /path/to/plans/plan_dsas_pen2_mdiv1.0_qth0.7_budget30.csv \
--val_csv /path/to/data/promise/mydata/splits/valid.csv \
--pretrain_weights /path/to/checkpoint/backbone.safetensors \
--output_dir /path/to/train/runs \
--exp_name dsas_budget30
The trainer freezes the SAM3 visual backbone and optimizes the segmentation layers using Dice loss plus cross-entropy loss. Each run saves logs, TensorBoard files, latest_checkpoint.pth, and best_checkpoint.pth under the experiment directory.
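For reference, a combined Dice plus cross-entropy objective for a single slice can be sketched in NumPy. This is an illustration only, not the loss code in train/train.py. The equal 1:1 weighting, the epsilon smoothing, and averaging Dice over all classes (background included) are all assumptions.

```python
import numpy as np


def softmax(logits, axis=0):
    z = logits - logits.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dice_ce_loss(logits, target, num_classes, eps=1e-6):
    """Sketch: logits (C, H, W) raw scores, target (H, W) integer labels; assumed 1:1 weighting."""
    probs = softmax(logits, axis=0)
    onehot = np.eye(num_classes)[target].transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    # Cross-entropy: mean negative log-probability of the true class over pixels.
    ce = -np.mean(np.sum(onehot * np.log(probs + eps), axis=0))
    # Soft Dice per class, then averaged over classes.
    inter = np.sum(probs * onehot, axis=(1, 2))
    denom = np.sum(probs, axis=(1, 2)) + np.sum(onehot, axis=(1, 2))
    dice = (2 * inter + eps) / (denom + eps)
    return float(ce + (1.0 - dice.mean()))
```

Confident correct predictions drive both terms toward zero. Uniform logits over two classes give about ln 2 + 0.5 (cross-entropy ln 2, mean Dice 0.5).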
Several scripts still contain hard-coded paths from the original experiment environment. Update those paths before running the pipeline on a new machine.
