LangDAug: Langevin Data Augmentation for Multi-Source Domain Generalization in Medical Image Segmentation

Indian Institute of Science, Bangalore

ICML 2025

Overview of our proposed LangDAug method.

Abstract

LangDAug is a Langevin-based data augmentation method for improving domain generalization in 2D medical image segmentation. It leverages Energy-Based Models (EBMs) trained via contrastive divergence to generate intermediate samples that bridge source domains using Langevin Dynamics. These samples act as natural augmentations, improving generalization to unseen domains.

We show that LangDAug provides a regularizing effect, theoretically bounding model complexity by the intrinsic dimensionality of the data manifold. Empirically, it outperforms state-of-the-art methods on retinal fundus and prostate MRI segmentation tasks, and complements domain-randomization strategies effectively.

Methodology

Problem Setup

Given multiple source domains $\{D_i\}_{i=1}^n$, each associated with a distribution $P_{D_i}(x, y)$ over the input-output space $\mathcal{X} \times \mathcal{Y}$, the standard Empirical Risk Minimization (ERM) objective seeks model parameters $\hat{\theta}$ that minimize the average loss over all training samples:

$$\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} \ell(f_\theta(x_i), y_i)$$

where $N = \sum_{i=1}^n |D_i|$ denotes the total number of training samples aggregated from all source domains. However, ERM often fails to generalize to an unseen target domain $D_{n+1} \notin \{D_i\}_{i=1}^n$, since it only optimizes performance over the observed source domains.
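As a concrete illustration, here is a minimal numpy sketch of pooled-domain ERM: samples from several toy source domains are aggregated into one training set and a linear model is fit under squared loss (the domain means, data, and linear model are hypothetical stand-ins for the segmentation networks in the paper).

```python
import numpy as np

rng = np.random.default_rng(0)

# Pool samples from three hypothetical source domains into one training set.
domains = [rng.normal(loc=m, size=(50, 3)) for m in (0.0, 1.0, 2.0)]
X = np.vstack(domains)                      # N x d pooled inputs
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true                              # toy noiseless labels

# ERM with squared loss: theta_hat = argmin_theta (1/N) sum_i (f_theta(x_i) - y_i)^2.
# For a linear f_theta this is ordinary least squares.
theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

avg_loss = np.mean((X @ theta_hat - y) ** 2)
```

ERM fits the pooled data well, but nothing in the objective accounts for a fourth, unseen domain — which is the failure mode LangDAug targets.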

Inter-Domain Traversal with EBMs

To bridge domains, we train an EBM $E_{\theta_{ij}}$ to model the energy between each domain pair $(D_i, D_j)$. The model is trained using Contrastive Divergence:

$$\nabla_{\theta_{ij}} \mathcal{L}_{CD} = \mathbb{E}_{x \sim P_{D_j}}\left[\nabla_{\theta_{ij}} E_{\theta_{ij}}(x)\right] - \mathbb{E}_{x \sim P_{\theta_{ij}}}\left[\nabla_{\theta_{ij}} E_{\theta_{ij}}(x)\right]$$
$$P_{\theta_{ij}}(x) = \frac{\exp(-E_{\theta_{ij}}(x))}{Z_{\theta_{ij}}}, \quad \text{where } Z_{\theta_{ij}} = \int_{\mathcal{X}} \exp(-E_{\theta_{ij}}(x)) \, dx$$
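To make the update concrete, here is a toy numpy sketch of contrastive divergence: the positive term averages the energy gradient over target-domain data, the negative term over model samples (the quadratic energy, Gaussian data, and step size are illustrative assumptions; the paper's EBMs are neural networks trained on images).

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy energy E_theta(x) = ||x - theta||^2 / 2, so grad_theta E = theta - x.
def grad_theta_E(x, theta):
    return theta - x

theta = np.zeros(2)
data = rng.normal(loc=3.0, size=(256, 2))   # hypothetical target-domain samples

for _ in range(200):
    # For this energy the model distribution is exactly N(theta, I), so we can
    # draw negative samples directly instead of running Langevin chains.
    neg = theta + rng.normal(size=(256, 2))
    # CD gradient: E_data[grad_theta E] - E_model[grad_theta E].
    g = grad_theta_E(data, theta).mean(0) - grad_theta_E(neg, theta).mean(0)
    theta -= 0.1 * g
```

Descending this gradient lowers the energy on data and raises it on model samples; here `theta` converges to the data mean, i.e. the model distribution matches the data.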

Sampling from $P_{\theta_{ij}}$ is performed using Langevin Dynamics (LD), with the chain initialized at a point $x_0 \sim P_{D_i}$:

$$x_{t+1} = x_t - \frac{\alpha^2}{2} \nabla_x E_{\theta_{ij}}(x_t) + \alpha \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

These LD iterates form samples that interpolate between domains.
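The traversal can be sketched in a few lines of numpy: a chain started at a point far from the energy minimum (standing in for a source-domain sample) drifts toward the low-energy region (standing in for the target domain), and the intermediate iterates trace the path between them. The quadratic energy and all constants here are toy assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy energy with minimum at mu, so grad_x E(x) = x - mu.
mu = np.array([2.0, -1.0])
def grad_E(x):
    return x - mu

def langevin_chain(x0, alpha=0.1, steps=500):
    """Run Langevin dynamics from x0 and return all iterates."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        eps = rng.normal(size=x.shape)
        x = x - 0.5 * alpha**2 * grad_E(x) + alpha * eps
        xs.append(x)
    return np.array(xs)

# Initialize at a hypothetical "source domain" point, far from the minimum.
chain = langevin_chain(np.array([-5.0, 5.0]))
```

The early iterates stay close to the starting point and the late iterates concentrate near `mu`, which is exactly the interpolation behavior LangDAug exploits.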

Langevin Data Augmentation

We use the intermediate LD samples as augmentation data. For each sample $x_j \in D_i$, LD is run for $K$ steps to generate $\{x_j^t\}_{t=1}^K$:

$$x_j^{t+1} = x_j^t - \frac{\beta^2}{2} \nabla_x E_{\theta_{ij}}(x_j^t) + \beta \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

These samples, paired with the original labels $y_j$, are used in ERM training, effectively expanding the domain support:

$$\mathcal{D}_{\text{aug}} = \bigcup_{i \neq j,\, k} D_{ij}^k, \quad \text{where } D_{ij}^k = \{(x_j^k, y_j)\}$$
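The augmentation step can be sketched as follows: run $K$ Langevin steps from a labeled source sample and keep every iterate, each paired with the unchanged label (the bridging energy here is a hypothetical quadratic pulling toward a toy target-domain mean; in the paper the gradient comes from the trained EBM $E_{\theta_{ij}}$, and labels can be reused because the perturbations preserve segmentation structure).

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical trained energy bridging domains i -> j: a toy quadratic whose
# gradient pulls samples toward the target-domain mean mu_j.
mu_j = np.array([4.0, 4.0])
def grad_E_ij(x):
    return x - mu_j

def langdaug_samples(x, K=8, beta=0.5):
    """Return the K Langevin iterates of x; each keeps x's original label."""
    out = []
    for _ in range(K):
        x = x - 0.5 * beta**2 * grad_E_ij(x) + beta * rng.normal(size=x.shape)
        out.append(x)
    return np.array(out)

# One labeled sample from source domain D_i.
x_j, y_j = np.array([0.0, 0.0]), 1
aug = [(x_k, y_j) for x_k in langdaug_samples(x_j)]  # augmented pairs (x_j^k, y_j)
```

Training then proceeds by ERM over the union of the original data and these augmented pairs.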

Theoretical Insights

LangDAug acts as a regularizer on the ERM objective. Let $\tilde{x}_i$ denote the Langevin-perturbed sample and $\tilde{z}_i = (\tilde{x}_i, y_i)$ the corresponding augmented pair. The augmented empirical risk is:

$$\mathcal{L}_{\text{aug}}(\theta, \mathcal{D}) = \frac{1}{k} \sum_{i=1}^k \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)} \left[\ell(\theta, \tilde{z}_i)\right]$$

This can be decomposed as:

$$\mathcal{L}_{\text{aug}} = \mathcal{L}_{\text{std}} + R_1 + R_2 + R_3$$

where $R_1, R_2, R_3$ are regularization terms involving the first and second derivatives of the model $f_\theta$, encouraging smoother and flatter solutions.
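The shape of these terms can be seen from the standard noise-injection argument (a sketch under simplifying assumptions; the paper's exact $R$ terms additionally account for the Langevin drift and noise scales separately). Writing $\tilde{x} = x + \delta(\epsilon)$ and Taylor-expanding the loss to second order around $x$:

```latex
\mathbb{E}_{\epsilon}\big[\ell(\theta, \tilde{x})\big]
\;\approx\; \ell(\theta, x)
\;+\; \mathbb{E}_{\epsilon}[\delta]^{\top} \nabla_x \ell(\theta, x)
\;+\; \tfrac{1}{2}\, \mathbb{E}_{\epsilon}\big[\delta \delta^{\top}\big] : \nabla_x^2 \ell(\theta, x)
```

The drift term penalizes first derivatives of $f_\theta$ (via the chain rule $\nabla_x \ell = (\nabla_x f_\theta)^{\top} \nabla_f \ell$) and the noise term penalizes second derivatives, which is the source of the smoothness and flatness interpretation above.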

Cross-Domain Generalization Performance

Retinal Fundus Segmentation

Per-domain columns report mIoU for the held-out domain; all values in %, higher is better.

| Method | Domain A | Domain B | Domain C | Domain D | Avg mIoU | Avg mDSC |
|---|---|---|---|---|---|---|
| Hutchinson | 66.73 | 66.73 | 69.36 | 66.73 | 67.39 | 78.14 |
| MixStyle | 80.76 | 67.69 | 79.79 | 77.09 | 76.33 | 85.58 |
| FedDG | 76.65 | 72.14 | 76.10 | 75.96 | 75.21 | 83.67 |
| RAM | 77.42 | 73.79 | 79.66 | 78.74 | 77.40 | 85.39 |
| TriD | 80.92 | 72.45 | 79.34 | 78.96 | 77.92 | 85.95 |
| LangDAug (Ours) | 78.79 | 75.05 | 81.01 | 80.51 | 78.84 | 87.61 |

Prostate MRI Segmentation

Per-domain columns report ASD for the held-out domain (lower is better); Avg DSC is in %, higher is better.

| Method | Domain A | Domain B | Domain C | Domain D | Domain E | Domain F | Avg ASD | Avg DSC |
|---|---|---|---|---|---|---|---|---|
| Hutchinson | 3.28 | 1.48 | 2.07 | 3.98 | 2.78 | 1.64 | 2.54 | 78.62 |
| MixStyle | 0.72 | 0.88 | 1.62 | 0.65 | 1.59 | 0.51 | 1.00 | 86.27 |
| FedDG | 1.09 | 0.93 | 1.31 | 0.88 | 1.73 | 0.50 | 1.07 | 85.95 |
| RAM | 0.93 | 0.98 | 1.26 | 0.74 | 1.78 | 0.32 | 1.00 | 87.02 |
| TriD | 0.70 | 0.72 | 1.39 | 0.71 | 1.43 | 0.46 | 0.90 | 87.68 |
| LangDAug (Ours) | 0.58 | 0.64 | 1.21 | 0.57 | 1.49 | 0.38 | 0.81 | 89.16 |

Inter-Domain Traversal Examples

Retinal Fundus Dataset

Traversal examples (figures): Domain A → Domain D, Domain B → Domain D, Domain C → Domain B, Domain D → Domain A.

Prostate MRI Dataset

Traversal examples (figures): Domain A → Domain B, Domain B → Domain F, Domain C → Domain A, Domain D → Domain F, Domain E → Domain B, Domain F → Domain A.

BibTeX citation

If you find this work useful, please cite:

@inproceedings{tiwary2025langdaug,
  title={LangDAug: Langevin Data Augmentation for Multi-Source Domain Generalization in Medical Image Segmentation},
  author={Tiwary, Piyush and Bhattacharyya, Kinjawl and Prathosh, A.P.},
  booktitle={Proceedings of the 42nd International Conference on Machine Learning},
  year={2025}
}