
Multimodal XAI Prostate

This is the official documentation for the "Multimodal XAI Prostate" model, derived from the study "Integrating Multimodal Deep Learning and Explainable AI for Enhanced Prostate Lesion Classification" by Claudio Giovannoni et al. For complete author information, the dataset, the abstract and the workflow, please refer to the related online repository on GitHub.

Description

Multimodal XAI Prostate is a Python package that implements a complete pipeline for explainable multimodal prostate MRI lesion classification. It fuses three imaging modalities (ADC, T2w, DWI) with tabular clinical metadata, which enhances the model's performance in predicting prostate cancer and enables interpretability at both the global and the local level.

It provides:

  • Multimodal model training & tracking: train the multimodal model or load the pre-trained best model
  • Multimodal Explainability: visual-based local XAI with Grad-CAM heatmaps and quantitative-based global XAI via SHAP attributions for tabular features

Derived from “Integrating Multimodal Deep Learning and Explainable AI for Enhanced Prostate Lesion Classification”, this implementation wraps the original network in an open and transparent package, shipped as a Tango artifact under the name multimodal_xai_prostate.

Once registered, the Tango driver exposes a single run entrypoint that:

  • Accepts a dict of raw image bytes (ADC, T2w, DWI) and a dict of clinical metadata
  • Performs preprocessing, model inference, and XAI post-processing in one call
  • Returns predictions, probabilities, and multimodal explanations (local visual-based, plus local and global quantitative-based) as images and arrays
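The input contract of `run` can be illustrated by assembling the two dictionaries from files on disk. This is a minimal sketch: the helper `build_run_inputs` and the key names (`adc`, `t2w`, `dwi`, `patient_age`) are hypothetical, and the driver's actual `run` signature should be checked before relying on them.

```python
from pathlib import Path

# Hypothetical modality keys; the real driver may expect different names.
MODALITIES = ("adc", "t2w", "dwi")


def build_run_inputs(image_paths: dict, metadata: dict) -> tuple[dict, dict]:
    """Read raw image bytes per modality and copy the clinical metadata.

    image_paths maps each modality key to a file path; metadata maps
    clinical feature names to values. Both layouts are assumptions.
    """
    missing = [m for m in MODALITIES if m not in image_paths]
    if missing:
        raise KeyError(f"missing modalities: {missing}")
    # Raw bytes, as the entrypoint accepts undecoded image data.
    images = {m: Path(image_paths[m]).read_bytes() for m in MODALITIES}
    return images, dict(metadata)
```

The returned pair (image bytes, metadata) is what would then be handed to the `run` entrypoint.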

Inputs and Outputs

Inputs

Multimodal Model

  • Imaging Data:
    • Three Magnetic Resonance Imaging (MRI) modalities are used:
      • ADC (Apparent Diffusion Coefficient)
      • T2w (T2-weighted Imaging)
      • DWI (Diffusion Weighted Imaging)
    • Imaging data are converted from 3-D NIfTI volumes into 18 to 30 2-D PNG slices per patient.
    • All slices are cropped to a 64×64 Region of Interest (ROI) around the prostate.
    • For each patient, only 5 diagnostically relevant slices are extracted from each modality (15 in total per patient).
    • The three modalities are combined with the tabular data (transformed into a 64×64 matrix) and stacked as four channels into an array.
    • The patient-level arrays then go through a preprocessing and augmentation step.
    • The patient-level arrays are stacked into a tensor of shape (batch_size, 64, 64, 3), which is the input to the model.
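The slice selection, cropping and channel stacking described above can be sketched with plain Python lists. Two rules here are assumptions, not taken from the package: the 5 "diagnostically relevant" slices are approximated by the 5 central slices, and the ROI is approximated by a centred 64×64 crop. The sketch builds one (64, 64, 3) channel-last sample from one slice per modality (channel order ADC, T2w, DWI, also assumed); how the package maps the five slices per patient onto the batch dimension is not described here.

```python
def select_central_slices(slices: list, k: int = 5) -> list:
    """Hypothetical selection rule: keep the k slices centred in the stack."""
    if len(slices) < k:
        raise ValueError(f"need at least {k} slices, got {len(slices)}")
    start = (len(slices) - k) // 2
    return slices[start:start + k]


def center_crop(image: list, size: int = 64) -> list:
    """Crop a 2-D image (list of rows) to a size x size window at its centre."""
    height, width = len(image), len(image[0])
    if height < size or width < size:
        raise ValueError("image smaller than crop size")
    top, left = (height - size) // 2, (width - size) // 2
    return [row[left:left + size] for row in image[top:top + size]]


def stack_modalities(adc: list, t2w: list, dwi: list) -> list:
    """Combine three (64, 64) slices into one (64, 64, 3) channel-last array."""
    return [
        [[a, t, d] for a, t, d in zip(row_a, row_t, row_d)]
        for row_a, row_t, row_d in zip(adc, t2w, dwi)
    ]
```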

Unimodal Model

  • Tabular Metadata:
    • The tabular features include:
      • patient_age
      • PSA (Prostate-Specific Antigen) level
      • PSAD (PSA Density)
      • prostate_volume
  • Tabular Model:
    • The tabular model consists of a fully connected neural network that processes the tabular clinical metadata
    • Input data format: Pandas DataFrame containing patient clinical features
    • Features are preprocessed, normalized, and converted to arrays before being fed to the model
    • The model processes the same tabular features as the multimodal model but without the imaging data
    • This separate model is specifically used for generating SHAP explanations of the tabular features
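Normalization of the tabular features could look like the following sketch. It assumes z-score scaling (the scaling method is not specified above), uses hypothetical column names, and takes a list of dicts, such as the output of pandas `DataFrame.to_dict("records")`, in place of the DataFrame so that it runs without pandas.

```python
from statistics import mean, pstdev

# Hypothetical column names mirroring the features listed above.
FEATURES = ("patient_age", "psa", "psad", "prostate_volume")


def fit_scaler(rows: list[dict]) -> dict:
    """Compute per-feature mean and population standard deviation."""
    stats = {}
    for name in FEATURES:
        values = [float(row[name]) for row in rows]
        std = pstdev(values)
        # Fall back to 1.0 for constant features to avoid division by zero.
        stats[name] = (mean(values), std if std > 0 else 1.0)
    return stats


def transform(rows: list[dict], stats: dict) -> list[list[float]]:
    """Z-score each feature and return one feature vector per patient."""
    return [
        [(float(row[name]) - stats[name][0]) / stats[name][1] for name in FEATURES]
        for row in rows
    ]
```

Fitting the statistics on the training split only and reusing them at inference keeps test data out of the scaler.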

Outputs

Prediction Outputs

  1. Binary Classification Result:
    • The model returns a binary classification vector, with one label per patient:
      • 0: non-csPCa (non-clinically significant prostate cancer, or benign lesion)
      • 1: csPCa (clinically significant prostate cancer)
    • This output is the predicted class, derived from the final sigmoid activation layer of the CNN-based multimodal architecture.
    • Format: NumPy array of shape (batch_size,) containing binary values (0 or 1)
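Turning the sigmoid output into the 0/1 labels above amounts to thresholding a probability. The sketch below assumes a threshold of 0.5, a common default that this documentation does not state; the function names are illustrative.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function, as computed by a sigmoid layer."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow of exp(-x) for large negative x
    return z / (1.0 + z)


def to_labels(probabilities: list[float], threshold: float = 0.5) -> list[int]:
    """Map csPCa probabilities to 0 (non-csPCa) or 1 (csPCa)."""
    return [1 if p >= threshold else 0 for p in probabilities]
```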

Explanation Outputs

  1. Local Visual-based Explainability (Grad-CAM):

    • Channel-specific heatmaps for each imaging modality (T2w, ADC, DWI)
    • Average heatmap: combined visualization across all modalities
    • Format:
      • NumPy arrays of shape (64, 64) containing activation values
      • PNG visualization files showing:
        • Original image
        • Heatmap overlay
        • Combined visualization
    • Organization:
      • Outputs are organized by patient and modality
      • Each patient has a dedicated directory containing all visualizations
  2. Local and Global Quantitative-based Explainability (SHAP):

    • Feature Importance:
      • Bar chart showing the mean absolute SHAP values for each feature, with separate bars for csPCa and non-csPCa cases
    • Single Instances Feature Impact:
      • Scatter plot visualization showing the distribution of SHAP values vs. feature values, colored by class (csPCa vs. non-csPCa)
    • Format:
      • PNG visualization files
      • SHAP values as NumPy arrays
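For reference, the standard Grad-CAM computation behind such heatmaps weights each feature map of a convolutional layer by the spatial mean of its gradient, sums the weighted maps and applies a ReLU. The sketch below implements that on plain nested lists, plus an element-wise mean for the average heatmap. It assumes activations and gradients have already been extracted from the model (e.g. with a TensorFlow gradient tape), omits the upsampling to 64×64, and does not claim to reproduce how the package derives one map per modality.

```python
def grad_cam(activations: list, gradients: list) -> list:
    """Grad-CAM map from feature maps A[h][w][k] and gradients dy/dA[h][w][k].

    alpha_k = spatial mean of the gradients of channel k;
    cam[h][w] = ReLU(sum_k alpha_k * A[h][w][k]), then scaled to [0, 1].
    """
    height, width = len(activations), len(activations[0])
    channels = len(activations[0][0])
    alphas = [
        sum(gradients[h][w][k] for h in range(height) for w in range(width))
        / (height * width)
        for k in range(channels)
    ]
    cam = [
        [
            max(0.0, sum(alphas[k] * activations[h][w][k] for k in range(channels)))
            for w in range(width)
        ]
        for h in range(height)
    ]
    peak = max(max(row) for row in cam)
    if peak > 0:  # leave an all-zero map unchanged
        cam = [[v / peak for v in row] for row in cam]
    return cam


def average_heatmap(heatmaps: list) -> list:
    """Element-wise mean of same-sized heatmaps (e.g. T2w, ADC, DWI)."""
    n = len(heatmaps)
    return [[sum(values) / n for values in zip(*rows)] for rows in zip(*heatmaps)]
```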

API Response Structure

When deployed as a service, the model returns a JSON response with the following structure:

{
  "predictions": [0, 1, 0, ...],  // Binary classification results
  "explanations": {
    "visual": {
      "patient_0": {
        "t2w": "path/to/heatmap.png",
        "adc": "path/to/heatmap.png",
        "dwi": "path/to/heatmap.png",
        "avg": "path/to/heatmap.png"
      },
      // More patients...
    },
    "tabular": {
      "global_importance": "path/to/feature_importance.png",
      "local_impact": "path/to/feature_impact.png",
      "shap_values": [...]  // Raw SHAP values
    }
  }
}
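The `//` comments in the structure above are annotations; an actual response is plain JSON. Assuming that `shap_values` holds one row per patient, aligned with `predictions` and ordered by a hypothetical feature list, the per-class mean absolute SHAP values shown in the feature-importance chart could be recomputed from a response like this. Grouping uses predicted classes, which is an assumption: the chart may be split by true labels instead.

```python
import json
from collections import defaultdict

# Hypothetical feature order for the rows of "shap_values".
FEATURES = ("patient_age", "psa", "psad", "prostate_volume")


def mean_abs_shap_by_class(response_text: str) -> dict:
    """Mean |SHAP| per feature, grouped by predicted class (0 or 1)."""
    response = json.loads(response_text)
    predictions = response["predictions"]
    shap_rows = response["explanations"]["tabular"]["shap_values"]
    grouped = defaultdict(list)
    for label, row in zip(predictions, shap_rows):
        grouped[label].append(row)
    return {
        label: {
            name: sum(abs(row[i]) for row in rows) / len(rows)
            for i, name in enumerate(FEATURES)
        }
        for label, rows in grouped.items()
    }
```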

A shallow tree of the artifacts returned by each model run:

.
├── explainers
│   ├── multimodal_explainer
│   └── tabular_explainer
├── explanations
│   ├── multimodal
│   └── tabular
├── models
│   ├── multimodal_model
│   └── tabular_model
├── signature.json
└── tabular_signature.json
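After downloading a run's artifacts to a local directory (for example with `mlflow.artifacts.download_artifacts`), a quick check that the layout matches the tree above might look like this. `missing_artifacts` is a hypothetical helper; the expected paths are copied from the tree.

```python
from pathlib import Path

# Expected entries, copied from the artifact tree above.
EXPECTED_ARTIFACTS = (
    "explainers/multimodal_explainer",
    "explainers/tabular_explainer",
    "explanations/multimodal",
    "explanations/tabular",
    "models/multimodal_model",
    "models/tabular_model",
    "signature.json",
    "tabular_signature.json",
)


def missing_artifacts(root: str) -> list[str]:
    """Return the expected artifact paths that do not exist under root."""
    base = Path(root)
    return [rel for rel in EXPECTED_ARTIFACTS if not (base / rel).exists()]
```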

Model signature

{
  "inputs": [
    {
      "type": "tensor",
      "tensor-spec": {
        "dtype": "uint16",
        "shape": [-1, 64, 64, 3]
      }
    }
  ],
  "outputs": [
    {
      "type": "tensor",
      "tensor-spec": {
        "dtype": "int64",
        "shape": [-1]
      }
    }
  ]
}
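A batch can be checked against this signature before it is sent for inference. The sketch below validates a nested-list batch for shape [-1, 64, 64, 3] and the uint16 value range; `check_signature` is a hypothetical helper, and the package itself may rely on MLflow's own schema enforcement instead.

```python
UINT16_MAX = 65535  # largest value representable as uint16


def check_signature(batch: list) -> int:
    """Check a nested-list batch against shape [-1, 64, 64, 3], dtype uint16.

    Returns the batch size (any size is allowed, matching -1);
    raises ValueError on the first violation found.
    """
    for i, sample in enumerate(batch):
        if len(sample) != 64 or any(len(row) != 64 for row in sample):
            raise ValueError(f"sample {i}: expected 64x64 spatial size")
        for row in sample:
            for pixel in row:
                if len(pixel) != 3:
                    raise ValueError(f"sample {i}: expected 3 channels")
                if any(not isinstance(v, int) or not 0 <= v <= UINT16_MAX for v in pixel):
                    raise ValueError(f"sample {i}: values must be uint16")
    return len(batch)
```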

Build

Init project

Install requirements.

pip install -r requirements.txt

Run a development environment

Install requirements.

pip install -r requirements-dev.txt

If needed, install a local version of tango-interfaces.

pip uninstall tango-interfaces
pip install -e <path>/tango-interfaces/

Run tracking server and model registry

mlflow ui

Prepare the environment for training/serving

Export the following environment variables.

export MLFLOW_TRACKING_URI="http://127.0.0.1:5000"
export MLFLOW_EXPERIMENT_NAME="test"

Use these if the server requires authentication.

export MLFLOW_TRACKING_USERNAME=<username>
export MLFLOW_TRACKING_PASSWORD=<password>
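Before starting a run, a small pre-flight check can report missing MLflow variables. This is a hypothetical helper, not part of the package: MLflow reads these variables on its own, and the check only flags the ones listed above that are unset, or an incomplete username/password pair.

```python
import os

# Variable names taken from the export commands above.
REQUIRED = ("MLFLOW_TRACKING_URI", "MLFLOW_EXPERIMENT_NAME")
AUTH = ("MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD")


def check_mlflow_env(env=None) -> list[str]:
    """Return a list of problems with the MLflow-related environment variables."""
    env = os.environ if env is None else env
    problems = [f"{name} is not set" for name in REQUIRED if not env.get(name)]
    set_auth = [name for name in AUTH if env.get(name)]
    if len(set_auth) == 1:  # authentication needs both or neither
        problems.append(f"only {set_auth[0]} is set; set both or neither")
    return problems
```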

Make a training run

Before running the model, set PYTHONPATH as follows (bash):

export PYTHONPATH=${PWD}/src/

Or, for the fish shell:

export PYTHONPATH={$PWD}/src/

To make a training run without loading the already trained best models, execute the following command (this takes some time, as it creates the environment).

python -m multimodal_xai_prostate.main

The configurable parameters are "skip-training" (boolean), which loads the best model directly instead of training it, and "num_patients" (integer), which controls the number of patients for which explanations are generated. For example:

python -m multimodal_xai_prostate.main --no-training --num-patients=10
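The example command uses `--no-training` and `--num-patients`. A minimal sketch of how such flags could be parsed, assuming Python's standard `argparse` (3.9+ for `BooleanOptionalAction`); `build_parser` is a hypothetical name, and the package's actual parser in `multimodal_xai_prostate.main` may define these options differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser accepting the flags used in the example command."""
    parser = argparse.ArgumentParser(prog="multimodal_xai_prostate.main")
    # Generates both --training and --no-training; --no-training loads the best model.
    parser.add_argument("--training", action=argparse.BooleanOptionalAction, default=True)
    # Number of patients for which explanations are generated (None: not limited here).
    parser.add_argument("--num-patients", type=int, default=None)
    return parser
```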

Run a local model server

Install requirements.

pip install -r requirements-dev-modelserver.txt

Run a model server

Open the experiment in the tracking server web UI at http://127.0.0.1:5000.


Copy the run ID of the experiment run.


Use the run ID to set the corresponding environment variable.

export MLFLOW_RUN_ID=<runid>

Serve the model identified by the run ID. Use a port other than 5000, which is already taken by the tracking server.

mlflow models serve -m runs:/$MLFLOW_RUN_ID/model --enable-mlserver -p 5001
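Once the server is up, a client can send a batch that follows the model signature. This sketch assumes the server exposes MLflow's `/invocations` scoring endpoint and accepts the `{"inputs": ...}` tensor format, both of which hold for MLflow's default scoring server and should be verified for the MLServer backend. `url` must point at `/invocations` on the host and port given to `-p`; the function names are illustrative.

```python
import json
import urllib.request


def build_payload(batch: list) -> bytes:
    """Encode a (batch, 64, 64, 3) nested list as an MLflow "inputs" JSON body."""
    return json.dumps({"inputs": batch}).encode("utf-8")


def predict(batch: list, url: str) -> object:
    """POST the batch to the scoring server and return the decoded JSON reply."""
    request = urllib.request.Request(
        url,
        data=build_payload(batch),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

For example, `predict(batch, "http://127.0.0.1:<port>/invocations")` with a (batch, 64, 64, 3) uint16 batch; the reply is returned as decoded JSON without assuming its exact shape.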

Dealing with mutex errors on macOS 15.6.1

To date (07/10/2025), this model does not work correctly on macOS 15.6.1, presumably due to an incompatibility between TensorFlow and MLflow.

During training, the following error occurs:

libc++abi: terminating due to uncaught exception of type std::__1::system_error: mutex lock failed: Invalid argument

During Airflow execution, the flow hangs after the following log line:

{process_utils.py:194} INFO - [mutex.cc : 452] RAW: Lock blocking 0x600002ac54b8   @

As a workaround, train the model locally (e.g., targeting Databricks) as follows.

Create (or start and attach to) a dedicated Docker container based on a Linux distribution.

# create
docker run -it --name xai python:3.11 /bin/bash

# start and attach
docker start xai
docker exec -it xai /bin/bash

Create a folder in it.

mkdir xai
cd xai

From the Docker host, copy the source code to the container.

docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/src xai:/xai
docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/requirements.txt xai:/xai
docker cp ~/Work/tango/tango-library/multimodal_xai_prostate/requirements-dev.txt xai:/xai

Inside the container, install requirements and train the model.

# install requirements
pip install -r requirements.txt
pip install -r requirements-dev.txt

# set variables in the script and run it to train the model
./run_train_model_staging_databricks.sh