Multimodal XAI Prostate
This is the official documentation for the "Multimodal XAI Prostate" model, derived from the study "Integrating Multimodal Deep Learning and Explainable AI for Enhanced Prostate Lesion Classification" by Claudio Giovannoni et al. For complete information on the authors, dataset, abstract, and workflow, please refer to the related online repository available on GitHub.
Description
Multimodal XAI Prostate is a Python package that implements a complete pipeline for explainable multimodal prostate MRI lesion classification by fusing three imaging modalities (ADC, T2w, DWI) with clinical metadata in tabular form. Besides improving predictive performance for prostate cancer, this enables model interpretability at both the global and the local level.
It provides:
- Multimodal Model training & tracking: train the multimodal model or load the pre-trained best model
- Multimodal Explainability: visual-based local XAI with Grad-CAM heatmaps and quantitative-based global XAI via SHAP attributions for tabular features
Derived from "Integrating Multimodal Deep Learning and Explainable AI for Enhanced Prostate Lesion Classification", this implementation wraps the original network in an open and transparent package shipped as a Tango artifact under the name multimodal_xai_prostate.
Once registered, the Tango driver
exposes a single run entrypoint that:
- Accepts a dict of raw image bytes (ADC, T2w, DWI) and a dict of clinical metadata
- Performs preprocessing, model inference, and XAI post-processing in one call
- Returns predictions, probabilities, and multimodal explanations (local visual-based and local/global quantitative-based) in image and array format
Inputs and Outputs
Inputs
Multimodal Model
- Imaging Data:
  - Three Magnetic Resonance Imaging (MRI) modalities are used:
    - ADC (Apparent Diffusion Coefficient)
    - T2w (T2-weighted Imaging)
    - DWI (Diffusion-Weighted Imaging)
  - Imaging data are preprocessed from 3-D NIfTI volumes, generating from 18 to 30 2-D PNG slices per patient.
  - All slices are cropped to a 64×64 Region of Interest (ROI) around the prostate area.
  - For each patient, only 5 diagnostically relevant slices are extracted from each modality (15 total per patient).
  - The three modalities are combined with tabular data (transformed into a 64×64 matrix) and stacked as four channels into an array.
  - The patient-level arrays undergo a preprocessing and augmentation step.
  - The patient-level arrays are stacked into a tensor of shape (batch_size, 64, 64, 3), which is the input to the model (see the stacking sketch after this list).
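The following is a minimal sketch of the slice cropping and channel stacking described above, not the package's actual preprocessing code; the file names and ROI coordinates are illustrative, and only the three imaging channels of the final (batch_size, 64, 64, 3) tensor are shown.

```python
# Illustrative sketch only: file names, ROI box, and channel layout are assumptions.
import numpy as np
from PIL import Image

def load_slice(path: str, roi=(0, 0, 64, 64)) -> np.ndarray:
    """Load one 2-D PNG slice and crop it to a 64x64 ROI (left, upper, right, lower)."""
    img = Image.open(path).convert("L")
    return np.asarray(img.crop(roi), dtype=np.float32)

def stack_patient(adc_path: str, t2w_path: str, dwi_path: str) -> np.ndarray:
    """Stack the three modality slices as channels -> array of shape (64, 64, 3)."""
    return np.stack([load_slice(p) for p in (adc_path, t2w_path, dwi_path)], axis=-1)

# Patient-level arrays are then batched into the model input tensor.
batch = np.stack([stack_patient("adc.png", "t2w.png", "dwi.png")], axis=0)
print(batch.shape)  # (1, 64, 64, 3)
```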
Unimodal Model
- Tabular Metadata:
  - The tabular features include:
    - patient_age
    - PSA (Prostate-Specific Antigen) level
    - PSAD (PSA Density)
    - prostate_volume
- Tabular Model:
- The tabular model consists of a Fully Connected Neural Network processing the tabular clinical metadata
- Input data format: Pandas DataFrame containing patient clinical features
- Features are preprocessed, normalized, and converted to arrays before being fed to the model
- The model processes the same tabular features as the multimodal model but without the imaging data
- This separate model is specifically used for generating SHAP explanations of the tabular features (see the sketch after this list)
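Below is a hedged sketch of such a fully connected tabular network; the layer sizes, scaler, column names, and example values are assumptions for illustration, not the package's exact architecture.

```python
# Illustrative fully connected tabular network; layer sizes, scaler, and the
# example values are assumptions, not the package's exact architecture.
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import StandardScaler

features = ["patient_age", "PSA", "PSAD", "prostate_volume"]
df = pd.DataFrame({
    "patient_age": [64, 71],
    "PSA": [6.2, 11.4],
    "PSAD": [0.12, 0.28],
    "prostate_volume": [52.0, 40.5],
})

# Normalize the clinical features and convert to an array, as described above.
X = StandardScaler().fit_transform(df[features].to_numpy())

tabular_model = tf.keras.Sequential([
    tf.keras.Input(shape=(len(features),)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),   # csPCa probability
])
tabular_model.compile(optimizer="adam", loss="binary_crossentropy")
print(tabular_model.predict(X))                        # one probability per patient
```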
Outputs
Prediction Outputs
- Binary Classification Result:
  - The model returns a binary classification vector:
    - 0: non-csPCa (non-clinically significant Prostate Cancer, or benign lesion)
    - 1: csPCa (clinically significant Prostate Cancer)
  - This output corresponds to the predicted class and is derived from the final sigmoid activation layer of the CNN-based multimodal architecture (see the thresholding sketch after this list).
  - Format: NumPy array of shape (batch_size,) containing binary values (0 or 1)
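The binary vector is obtained by thresholding the sigmoid output; a minimal sketch follows (the 0.5 cutoff and the probability values are assumptions, not taken from the package):

```python
import numpy as np

# Example sigmoid outputs from the model (illustrative values).
probs = np.array([0.12, 0.87, 0.43])
# A 0.5 cutoff is assumed here; it may differ from the package's threshold.
preds = (probs >= 0.5).astype(np.int64)
print(preds)        # [0 1 0] -> shape (batch_size,)
```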
Explanation Outputs
- Local Visual-based Explainability (Grad-CAM; a minimal sketch follows this list):
  - Channel-specific heatmaps for each imaging modality (T2w, ADC, DWI)
  - Average heatmap: a combined visualization across all modalities
  - Format:
    - NumPy arrays of shape (64, 64) containing activation values
    - PNG visualization files showing:
      - Original image
      - Heatmap overlay
      - Combined visualization
  - Organization:
    - Outputs are organized by patient and modality
    - Each patient has a dedicated directory containing all visualizations
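The sketch below follows the standard Grad-CAM recipe for a single input; the layer name "last_conv" and the use of the first output unit as the csPCa score are assumptions, not necessarily the package's implementation.

```python
# Minimal Grad-CAM sketch; "last_conv" is an assumed layer name.
import numpy as np
import tensorflow as tf

def grad_cam(model: tf.keras.Model, image: np.ndarray, conv_layer: str = "last_conv") -> np.ndarray:
    """Return a (64, 64) heatmap of activation values for one (64, 64, 3) input."""
    grad_model = tf.keras.Model(model.inputs,
                                [model.get_layer(conv_layer).output, model.output])
    x = tf.convert_to_tensor(image[np.newaxis, ...], dtype=tf.float32)
    with tf.GradientTape() as tape:
        conv_out, pred = grad_model(x)
        score = pred[:, 0]                                # sigmoid csPCa score
    grads = tape.gradient(score, conv_out)                # d(score)/d(feature maps)
    weights = tf.reduce_mean(grads, axis=(1, 2))          # global-average-pool the gradients
    cam = tf.reduce_sum(conv_out * weights[:, tf.newaxis, tf.newaxis, :], axis=-1)[0]
    cam = tf.nn.relu(cam)                                 # keep positive contributions
    cam = cam / (tf.reduce_max(cam) + 1e-8)               # normalize to [0, 1]
    cam = tf.image.resize(cam[..., tf.newaxis], (64, 64))[..., 0]  # upsample to input size
    return cam.numpy()
```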
- Local and Global Quantitative-based Explainability (SHAP; a minimal sketch follows this list):
  - Feature Importance:
    - Bar chart showing the mean absolute SHAP values for each feature, with separate bars for csPCa and non-csPCa cases
  - Single-Instance Feature Impact:
    - Scatter plot visualization showing the distribution of SHAP values vs. feature values, colored by class (csPCa vs. non-csPCa)
  - Format:
    - PNG visualization files
    - SHAP values as NumPy arrays
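The sketch below shows one common way to produce such SHAP outputs with the shap library, reusing the `tabular_model` and `features` names from the tabular sketch above; the explainer type, background sample, and synthetic data are assumptions rather than the package's exact code.

```python
# Hedged SHAP sketch; KernelExplainer and the background sample are assumptions.
import numpy as np
import shap

rng = np.random.default_rng(0)
X = rng.normal(size=(60, len(features)))              # stand-in for the scaled clinical features

def predict(a):
    return tabular_model.predict(a).ravel()           # 1-D csPCa probabilities

explainer = shap.KernelExplainer(predict, X[:50])     # background/reference sample
shap_values = explainer.shap_values(X[50:])           # local attributions, one row per patient

# Global importance: mean |SHAP| per feature, rendered as a bar chart.
shap.summary_plot(shap_values, X[50:], feature_names=features, plot_type="bar")
# Local impact: SHAP value vs. feature value, one dot per patient.
shap.summary_plot(shap_values, X[50:], feature_names=features)
```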
API Response Structure
When deployed as a service, the model returns a JSON response with the following structure:
{
  "predictions": [0, 1, 0, ...],  // Binary classification results
  "explanations": {
    "visual": {
      "patient_0": {
        "t2w": "path/to/heatmap.png",
        "adc": "path/to/heatmap.png",
        "dwi": "path/to/heatmap.png",
        "avg": "path/to/heatmap.png"
      },
      // More patients...
    },
    "tabular": {
      "global_importance": "path/to/feature_importance.png",
      "local_impact": "path/to/feature_impact.png",
      "shap_values": [...]  // Raw SHAP values
    }
  }
}
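For illustration, a response with this shape can be walked as follows; the keys mirror the example above, while the paths and values are placeholders, not real outputs.

```python
# Hypothetical response dict mirroring the structure documented above.
response = {
    "predictions": [0, 1],
    "explanations": {
        "visual": {
            "patient_0": {"t2w": "p0/t2w.png", "adc": "p0/adc.png", "dwi": "p0/dwi.png", "avg": "p0/avg.png"},
            "patient_1": {"t2w": "p1/t2w.png", "adc": "p1/adc.png", "dwi": "p1/dwi.png", "avg": "p1/avg.png"},
        },
        "tabular": {"global_importance": "shap/bar.png", "local_impact": "shap/scatter.png", "shap_values": []},
    },
}

for i, pred in enumerate(response["predictions"]):
    label = "csPCa" if pred == 1 else "non-csPCa"
    avg_heatmap = response["explanations"]["visual"][f"patient_{i}"]["avg"]
    print(f"patient_{i}: {label} (average Grad-CAM at {avg_heatmap})")
```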
A shallow listing of the artifacts returned by each model run has the following structure:
.
├── explainers
│ ├── multimodal_explainer
│ └── tabular_explainer
├── explanations
│ ├── multimodal
│ └── tabular
├── models
│ ├── multimodal_model
│ └── tabular_model
├── signature.json
└── tabular_signature.json
Model signature
{
  "inputs": [
    {
      "type": "tensor",
      "tensor-spec": {
        "dtype": "uint16",
        "shape": [-1, 64, 64, 3]
      }
    }
  ],
  "outputs": [
    {
      "type": "tensor",
      "tensor-spec": {
        "dtype": "int64",
        "shape": [-1]
      }
    }
  ]
}
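For reference, a NumPy array conforming to this signature can be built as follows (zero-filled placeholder data, not real patient slices):

```python
import numpy as np

# Two patients' worth of placeholder input: dtype uint16, shape [-1, 64, 64, 3].
x = np.zeros((2, 64, 64, 3), dtype=np.uint16)
# The corresponding output is an int64 vector of length 2 (one binary label per patient).
```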
Build
Init project
Install requirements.
pip install -r requirements.txt
Run a development environment
Install requirements.
pip install -r requirements-dev.txt
If needed, install a local version of tango-interfaces.
pip uninstall tango-interfaces
pip install -e <path>/tango-interfaces/
Run tracking server and model registry
mlflow ui
Prepare the environment for training/serving
Export the following environment variables.
export MLFLOW_TRACKING_URI="http://127.0.0.1:5000"
export MLFLOW_EXPERIMENT_NAME="test"
Use these if the server requires authentication.
export MLFLOW_TRACKING_USERNAME=<username>
export MLFLOW_TRACKING_PASSWORD=<password>
Make a training run
Before running the model, make sure to use this (for bash terminal):
export PYTHONPATH=${PWD}/src/
Or this (for fish terminal):
set -x PYTHONPATH $PWD/src/
To make a training run without loading already trained best models, execute the following command (this operation takes some time, as it creates the environment).
python -m multimodal_xai_prostate.main
The settable parameters are "skip-training" (boolean), which allows the best model to be loaded directly, and "num_patients" (integer), which controls the number of patients for which explanations are generated. For example:
python -m multimodal_xai_prostate.main --no-training --num-patients=10
Running local model server
Install requirements.
pip install -r requirements-dev-modelserver.txt
Run a model server
Access the experiment from the tracking server web UI at http://127.0.0.1:5000.

Copy the experiment runid.

Use the runid to set the corresponding environment variable.
export MLFLOW_RUN_ID=<runid>
Serve the model identified by the runid.
mlflow models serve -m runs:/$MLFLOW_RUN_ID/model --enable-mlserver -p 5000
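Once the server is up, it can be scored over HTTP. The sketch below uses MLflow's standard /invocations scoring route with an {"inputs": ...} tensor payload matching the model signature; the zero-filled tensor and the response handling are illustrative, and the exact request/response schema should be verified for this deployment.

```python
# Illustrative client call; the port follows the commands above, and the
# zero-filled tensor is a stand-in for real preprocessed patient data.
import json
import numpy as np
import requests

batch = np.zeros((1, 64, 64, 3), dtype=np.uint16)   # conforms to the model signature
resp = requests.post(
    "http://127.0.0.1:5000/invocations",
    headers={"Content-Type": "application/json"},
    data=json.dumps({"inputs": batch.tolist()}),
)
resp.raise_for_status()
print(resp.json())   # e.g. {"predictions": [0]}
```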
Dealing with mutex errors on macOS 15.6.1
To date (07/10/2025) this model is not working correctly on macOS 15.6.1, presumably due to an incompatibility between TensorFlow and MLflow.
During the training phase, we are getting the following error:
libc++abi: terminating due to uncaught exception of type std::__1::system_error: mutex lock failed: Invalid argument
During Airflow execution, the flow gets stuck after the following log:
{process_utils.py:194} INFO - [mutex.cc : 452] RAW: Lock blocking 0x600002ac54b8 @
The workaround to train a model locally (e.g., targeting Databricks) is the following.
Create (or start and access) a dedicated Docker container based on a Linux distribution.
# create
docker run -it --name xai python:3.11 /bin/bash
# start and access to
docker start xai
docker exec -it xai /bin/bash
Create a folder in it.
mkdir xai
cd xai
From the Docker host, copy the source code to the container.
docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/src xai:/xai
docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/requirements.txt xai:/xai
docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/requirements-dev.txt xai:/xai
Inside the container, install requirements and train the model.
# install requirements
pip install -r requirements.txt
pip install -r requirements-dev.txt
# set variables in the script and run it to train the model
./run_train_model_staging_databricks.sh