GNN architecture and the math behind pigauto

This article opens up pigauto’s engine room. It is intended for readers who want to know exactly what the package does between impute(traits, tree, covariates) and the returned pigauto_result. We cover the data flow, the gated ensemble formula, the phylogenetic baseline (joint MVN / threshold-joint / OVR), the GNN’s transformer blocks, the optional within-row cross-trait attention added in v0.9.3, and the calibration / uncertainty quantification machinery on top.

We assume familiarity with maximum-likelihood Brownian motion (BM) on trees and with standard transformer attention. We do not assume familiarity with the package’s internals.

1. End-to-end data flow

impute() chains six internal stages. Tensor shapes use \(n\) for observations (\(n = m\) in single-observation mode, \(n > m\) when species have multiple measurements), \(m\) for species (tree tips), \(p\) for the latent trait dimension (each categorical trait expands to \(K\) one-hot columns), \(k\) for spectral Laplacian features, \(c\) for covariate dimensions, and \(T\) for the number of posterior trees in tree-uncertainty pooling.

INPUT:
  traits      data.frame   n × p_traits     (mixed types, NAs allowed)
  tree        phylo        m tips
  covariates  data.frame   n × c            (optional)

STEP 1 — preprocess_traits()
  X_scaled       n × p           per-type encoding + z-score
  trait_map      list            descriptors per trait (type, levels, mean, sd)
  obs_to_species n               only if n > m (multi-obs mode)

STEP 2 — build_phylo_graph()
  coords  m × k                  Laplacian eigenvectors
  adj     m × m                  Gaussian kernel on cophenetic distance
  D_sq    m × m                  squared cophenetic distance (B2)

STEP 3 — fit_baseline()  [FIXED, NOT TRAINED]
  mu  m × p                      phylogenetic baseline (BM / LP / joint)
  se  m × p                      conditional-MVN standard errors

STEP 4 — fit_pigauto() trains ResidualPhyloDAE on a DAE objective
  produces a torch nn_module + per-column gate

STEP 5 — calibrate_gates() picks per-trait r_cal on held-out val cells

STEP 6 — predict.pigauto_fit() applies the blend:
  pred = (1 - r_cal) · mu + r_cal · delta_GNN

The two essential ideas are:

  1. A simple, well-understood phylogenetic baseline does most of the work, and the GNN only contributes when calibrated to do so.
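  2. The blend gate \(r_\text{cal}\) is a safety floor: at \(r_\text{cal}=0\) the prediction collapses to the baseline. Because calibration can always choose \(r_\text{cal}=0\), a poorly trained GNN is simply switched off, so on the validation cells used for calibration the blend is never worse than the baseline, however poorly the GNN trains. On new missing cells the same holds to the extent that they resemble the validation cells.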

2. The phylogenetic baseline

The baseline is a per-trait-type dispatcher inside fit_baseline(). For each trait it returns \(\mu_i\) (the posterior mean for missing cell \(i\)) and \(\sigma_i\) (its posterior SD). All later stages condition on these.

2.1 Continuous, count, ordinal, proportion: Brownian motion

For one continuous trait \(y\) with \(y_O\) observed and \(y_M\) missing, \(y \sim \mathcal{N}(\beta \mathbf{1},\, \sigma^2 \mathbf{R})\) under BM on the species tree, where \(\mathbf{R} = \mathrm{cov2cor}(\mathrm{vcv}(\text{tree}))\) is the phylogenetic correlation matrix and \(\beta\) is the root mean (estimated by GLS from \(y_O\)). The posterior for the missing block is the closed-form conditional-MVN (GLS) solution:

\[ \hat{\mu}_M = \beta \mathbf{1}_M + \mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} (y_O - \beta \mathbf{1}_O), \quad \widehat{\sigma}^2_{M|O} = \sigma^2\bigl(1 - \mathrm{diag}(\mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} \mathbf{R}_{OM})\bigr). \]

The leading 1 in the variance is \(\mathrm{diag}(\mathbf{R}_{MM})\), which equals 1 because \(\mathbf{R}\) is a correlation matrix.

R/bm_internal.R::bm_impute_col() implements this directly. Count traits are log1p-transformed, ordinal traits are coerced to integers and z-scored, and proportions are logit-transformed and then z-scored. The same conditional formula is applied in each transformed space.
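The conditional formula can be checked on a toy tree. The base-R sketch below assumes a hand-built 4-tip variance-covariance matrix in place of ape::vcv(tree); bm_impute_sketch() is a hypothetical helper, not pigauto's bm_impute_col(), and it estimates \(\beta\) by GLS and \(\sigma^2\) with a REML-style \(n_O - 1\) denominator.

```r
# Per-column BM imputation via the conditional-MVN formula.
# bm_impute_sketch() is hypothetical; the vcv matrix below is hand-built
# for the tree ((A:1,B:1):1,(C:1.5,D:1.5):0.5) instead of using ape::vcv().
bm_impute_sketch <- function(y, R) {
  O <- which(!is.na(y))
  M <- which(is.na(y))
  R_OO_inv <- solve(R[O, O, drop = FALSE])
  one_O <- rep(1, length(O))
  # GLS estimate of the root mean beta from the observed tips
  beta <- as.numeric((one_O %*% R_OO_inv %*% y[O]) /
                     (one_O %*% R_OO_inv %*% one_O))
  resid_O <- y[O] - beta
  # REML-style estimate of the BM variance on the correlation scale
  sigma2 <- as.numeric(resid_O %*% R_OO_inv %*% resid_O) / (length(O) - 1)
  K <- R[M, O, drop = FALSE] %*% R_OO_inv                   # R_MO R_OO^{-1}
  mu_M <- beta + K %*% resid_O
  var_M <- sigma2 * (1 - rowSums(K * R[M, O, drop = FALSE]))  # 1 - diag(K R_OM)
  list(mean = as.numeric(mu_M), se = sqrt(var_M), beta = beta)
}

tips <- c("A", "B", "C", "D")
V <- matrix(c(2, 1, 0,   0,
              1, 2, 0,   0,
              0, 0, 2,   0.5,
              0, 0, 0.5, 2), nrow = 4, dimnames = list(tips, tips))
R <- cov2cor(V)
y <- c(A = 1.0, B = NA, C = -0.5, D = -0.3)
res <- bm_impute_sketch(y, R)
res$mean   # about 0.569: pulled from the root mean toward B's sister A
```

C and D share no branch with B, so they influence B's imputed value only through \(\hat\beta\).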

2.2 Binary and categorical: phylogenetic label propagation (LP)

Discrete traits use a softer phylogenetic prior: each species’ class probability is a kernel-weighted average over the other observed species, \(\Pr(y_i = k) \propto \sum_{j \ne i,\, j \in O} \mathbf{A}_{ij} \, \mathbb{1}(y_j = k)\), with \(\mathbf{A}\) the same Gaussian kernel used in the GNN’s adjacency. For categorical traits with \(K\) classes this returns \(K\) log-probability columns; binary traits return one logit column.
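A minimal sketch of the kernel vote, assuming a one-step (non-iterative) average exactly as in the formula above, a hand-made cophenetic distance matrix and a fixed bandwidth of 2. lp_sketch() is hypothetical; pigauto's own label propagation may iterate or set the bandwidth differently.

```r
# One-step kernel-weighted class vote: Pr(y_i = k) ∝ sum_{j != i} A_ij 1(y_j = k).
# lp_sketch(), the distances and the bandwidth are illustrative assumptions.
lp_sketch <- function(labels, A, levels) {
  diag(A) <- 0                                    # vote over the rest of the tree (j != i)
  obs <- which(!is.na(labels))
  onehot <- outer(labels[obs], levels, "==") * 1  # |obs| x K indicator matrix
  scores <- A[, obs, drop = FALSE] %*% onehot     # m x K kernel-weighted votes
  probs <- scores / rowSums(scores)               # row-normalise to probabilities
  colnames(probs) <- levels
  probs
}

tips <- c("A", "B", "C", "D")
D <- matrix(c(0, 2, 4, 4,
              2, 0, 4, 4,
              4, 4, 0, 3,
              4, 4, 3, 0), nrow = 4, dimnames = list(tips, tips))
A <- exp(-D^2 / (2 * 2^2))                        # Gaussian kernel, bandwidth 2 (assumed)
labels <- c(A = "herb", B = NA, C = "carn", D = "carn")
probs <- lp_sketch(labels, A, c("herb", "carn"))
round(probs["B", ], 3)   # herb ~0.69, carn ~0.31: the close sister outweighs two distant species
```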

2.3 Joint MVN baseline (Phase 2)

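When Rphylopars is installed and at least two continuous-like latent columns exist, pigauto upgrades to a joint multivariate-BM baseline. Stack the BM-eligible columns into \(\mathbf{Y} \in \mathbb{R}^{m \times p}\) (here \(p\) counts only those columns). Under joint BM, \(\mathrm{vec}(\mathbf{Y}) \sim \mathcal{N}(\boldsymbol{\beta} \otimes \mathbf{1}_m,\, \mathbf{\Sigma} \otimes \mathbf{R})\), where \(\boldsymbol{\beta} \in \mathbb{R}^p\) holds the per-trait root means and \(\mathbf{\Sigma}\) is the \(p \times p\) evolutionary covariance between traits. Rphylopars::phylopars() fits \(\hat{\mathbf{\Sigma}}\) jointly across traits, and the conditional posterior is the same conditional-MVN formula as in §2.1 with the Kronecker covariance in place of \(\sigma^2 \mathbf{R}\). This captures cross-trait phylogenetic correlation that the per-column path loses: a missing value of one trait can borrow strength from the observed values of correlated traits in the same species. The benchmark in script/bench_joint_baseline.R showed a 33.7% RMSE improvement on simulated correlated BM data.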
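To see what the Kronecker covariance buys, the sketch below treats \(\mathbf{\Sigma}\) and \(\boldsymbol{\beta}\) as known (in pigauto they come from phylopars()) and compares the joint conditional mean for one missing cell with the per-column one. cond_mvn() is a hypothetical helper and the data are made up.

```r
# Conditional MVN with Cov(vec(Y)) = Sigma ⊗ R, versus the per-column path.
# Sigma, beta = 0 and the trait values are illustrative assumptions.
cond_mvn <- function(y, mean, S) {
  O <- which(!is.na(y)); M <- which(is.na(y))
  K <- S[M, O, drop = FALSE] %*% solve(S[O, O, drop = FALSE])
  list(mean = as.numeric(mean[M] + K %*% (y[O] - mean[O])),
       var  = diag(S[M, M, drop = FALSE] - K %*% S[O, M, drop = FALSE]))
}

R <- cov2cor(matrix(c(2, 1, 0, 0,
                      1, 2, 0, 0,
                      0, 0, 2, 0.5,
                      0, 0, 0.5, 2), nrow = 4))  # species A, B, C, D
Sigma <- matrix(c(1, 0.8, 0.8, 1), 2)          # strong cross-trait correlation
Y <- cbind(t1 = c(1.0, NA, -0.5, -0.3),        # t1 missing for species B
           t2 = c(0.9, 1.5, -0.4, -0.2))
S <- kronecker(Sigma, R)                       # matches column-stacked as.vector(Y)
mu <- rep(c(0, 0), each = nrow(Y))             # vec(1 beta^T) with beta = (0, 0)

joint  <- cond_mvn(as.vector(Y), mu, S)
single <- cond_mvn(Y[, "t1"], rep(0, 4), R)    # per-column path ignores t2
c(joint = joint$mean, single = single$mean)    # 1.34 vs 0.5
c(joint = joint$var, single = single$var)      # 0.27 vs 0.75
```

Species B's second trait is high (1.5) and the traits are strongly correlated, so the joint posterior mean moves well above the per-column one and its variance shrinks.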

2.4 Threshold-joint baseline (Phase 3)

To bring binary traits inside the joint MVN, pigauto’s fit_joint_threshold_baseline() uses the Wright–Falconer liability model: each binary cell \(y \in \{0, 1\}\) indicates whether an underlying continuous liability \(L\) exceeds a threshold at 0, and each observed cell is replaced by the posterior mean of \(L\) truncated to the side indicated by \(y\). With prior \(L \sim \mathcal{N}(0, 1)\) this is the standard truncated-normal mean, \(E[L \mid L > 0] = \phi(0) / (1 - \Phi(0)) = \sqrt{2/\pi} \approx 0.798\) for \(y = 1\) and its negative for \(y = 0\). The resulting matrix is then handed to phylopars() exactly like the continuous case. Binary posteriors are decoded back to probabilities via \(p = \Phi\bigl(\mu_L / \sqrt{1 + \sigma_L^2}\bigr)\), clipped to \([0.01, 0.99]\). Ordinal liability uses \(K-1\) interval cuts (B3 ordinal).
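The truncated-normal means and the probit decode step are short enough to write out. This sketch assumes the threshold sits at 0 with prior \(L \sim \mathcal{N}(0, 1)\); liability_mean() and decode_binary() are hypothetical names, and the posterior means and variances passed to the decoder are made up.

```r
# Binary cell -> liability posterior mean, and liability posterior -> probability.
# Threshold at 0 and prior L ~ N(0, 1) are assumed; helper names are hypothetical.
liability_mean <- function(y) {
  ifelse(y == 1,
         dnorm(0) / pnorm(0, lower.tail = FALSE),   # E[L | L > 0]
         -dnorm(0) / pnorm(0))                      # E[L | L < 0]
}

decode_binary <- function(mu_L, var_L) {
  p <- pnorm(mu_L / sqrt(1 + var_L))   # p = Phi(mu_L / sqrt(1 + sigma_L^2))
  pmin(pmax(p, 0.01), 0.99)            # clip to [0.01, 0.99]
}

liability_mean(c(0, 1))                           # -0.798  0.798, i.e. -/+ sqrt(2/pi)
p_dec <- decode_binary(mu_L = c(-3, 0, 0.5), var_L = c(0.1, 0.5, 0.2))
p_dec                                             # first value clipped to 0.01; second is 0.5
```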

2.5 OVR categorical (Phase 6)

The single-fit approach to categorical liability (all \(K\) columns in one phylopars call) is rank-deficient and unstable: the \(K\) one-hot indicators sum to 1 in every row, so their covariance matrix is singular. pigauto instead runs \(K\) independent threshold-joint fits (class \(k\) vs the rest) and renormalises the resulting \(K\) probabilities into a row-stochastic distribution. This is BACE’s strategy and lifts AVONET Trophic.Level accuracy from ~42% to ~72%.
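Two small base-R checks on the OVR argument: the covariance of \(K\) one-hot columns has rank \(K-1\), and renormalising \(K\) one-vs-rest probabilities gives a row-stochastic matrix. The class labels and probability values are made up, and ovr_normalise() is a hypothetical helper.

```r
# (1) Why one joint fit on K one-hot columns is rank-deficient.
set.seed(1)
cls <- sample(c("H", "O", "C"), 50, replace = TRUE)
onehot <- outer(cls, c("H", "O", "C"), "==") * 1
qr(cov(onehot))$rank   # 2, not 3: the columns sum to 1, so their covariance is singular

# (2) Renormalising K separate one-vs-rest probabilities per species.
ovr_normalise <- function(p_ovr) p_ovr / rowSums(p_ovr)
p_ovr <- rbind(sp1 = c(0.70, 0.20, 0.30),   # separate OVR fits need not sum to 1
               sp2 = c(0.10, 0.10, 0.60))
colnames(p_ovr) <- c("Herbivore", "Omnivore", "Carnivore")
p <- ovr_normalise(p_ovr)
rowSums(p)             # 1 1
```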

2.6 Multi-obs aggregation (Phase 10 + B1 soft)

When each species has multiple observations, baselines run at species level. Phase 10 aggregates obs→species with type-aware rules: mean for continuous, modal class (or argmax one-hot) for discrete. B1 (v0.9.0) adds an opt-in soft path that preserves evidence strength: a species observed as class 1 in 6 of 10 rows (\(p = 0.6\)) uses the convex combination \(p \cdot E[L \mid L > 0] + (1 - p) \cdot E[L \mid L < 0]\), about 0.16 under the \(\mathcal{N}(0, 1)\) prior of §2.4, instead of collapsing to hard class 1 (liability about 0.80).
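The soft aggregation path can be sketched with tapply(). This assumes a binary trait, the threshold-at-0 liability with \(\mathcal{N}(0, 1)\) prior from §2.4, and a hypothetical soft_liability() helper; the observation data are made up.

```r
# B1-style soft aggregation of a binary trait from observations to species.
# soft_liability() is hypothetical; threshold 0 and N(0, 1) prior are assumed.
soft_liability <- function(y, species) {
  e_pos <- dnorm(0) / pnorm(0, lower.tail = FALSE)   # E[L | L > 0] = sqrt(2/pi)
  e_neg <- -dnorm(0) / pnorm(0)                      # E[L | L < 0] = -sqrt(2/pi)
  p <- tapply(y, species, mean, na.rm = TRUE)        # per-species fraction in class 1
  out <- p * e_pos + (1 - p) * e_neg                 # convex combination, not argmax
  setNames(as.numeric(out), names(out))
}

y <- c(rep(1, 6), rep(0, 4), 1, 1)
species <- c(rep("sp1", 10), "sp2", "sp2")
liab <- soft_liability(y, species)
round(liab, 3)   # sp1 0.160 (6/10 rows in class 1), sp2 0.798 (2/2 rows)
```

Under the hard path both species would receive 0.798; the soft path keeps sp1's weaker evidence visible to the joint baseline.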

3. The GNN: ResidualPhyloDAE

The GNN’s job is to learn an additive correction on top of the phylogenetic baseline. It is a denoising autoencoder (DAE) trained with masked-cell reconstruction loss.

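Name note. The torch class is ResidualPhyloDAE because its internal blocks use ResNet-style residual skip connections. The network output delta is not a statistical residual \(y - \mu\); it is a full per-cell prediction, blended externally with \(\mu\) via the per-trait gate. Because \((1 - r)\mu + r\,\mathrm{delta} = \mu + r(\mathrm{delta} - \mu)\), the blend still acts as an additive correction of size \(r(\mathrm{delta} - \mu)\) on top of the baseline.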

3.1 Encoder

Input tensors:

  • x \(\in \mathbb{R}^{n \times p}\): current trait latent matrix (missing cells replaced by a learnable mask token).
  • coords \(\in \mathbb{R}^{m \times k}\): species-level spectral Laplacian features.
  • covs \(\in \mathbb{R}^{n \times (2p + c)}\): the column-bound [baseline_mu | NA-mask | user_covs], i.e. the baseline prediction, a per-cell missingness indicator, and the \(c\) user covariates.

If use_trait_attention = TRUE (new in v0.9.3, see §3.5), a pooled trait-context feature of dimension trait_embed_dim is also concatenated. The encoder is a two-layer MLP:

\[ h = \mathrm{ReLU}\bigl(\mathbf{W}_2 \, \mathrm{Dropout}(\mathrm{ReLU}(\mathbf{W}_1 \mathrm{concat}(x, \text{coords}, \text{covs}, \ldots)))\bigr), \]

producing \(h \in \mathbb{R}^{n \times h_d}\) with hidden dim \(h_d\) (default 64).

In multi-obs mode, \(h\) is averaged across observations of the same species (scatter_mean) to produce a species-level hidden state \(h_\text{species} \in \mathbb{R}^{m \times h_d}\) for the graph message passing, then broadcast back to observation level afterwards.
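A base-R sketch of the encoder's forward pass and the scatter_mean pooling, assuming random weights, no bias terms, dropout switched off (as at inference), and made-up dimensions. It uses the row-vector convention input %*% W instead of \(\mathbf{W}x\); obs_to_species maps each observation row to its species index, and the species coords are broadcast to observation rows before concatenation.

```r
# Two-layer MLP encoder followed by scatter_mean over observations of a species.
# All weights and dimensions are illustrative; dropout is omitted (eval mode).
set.seed(42)
relu <- function(z) pmax(z, 0)
n <- 5; m <- 3; p <- 2; k <- 2; c_dim <- 1; h_d <- 4
x      <- matrix(rnorm(n * p), n, p)
coords <- matrix(rnorm(m * k), m, k)
covs   <- matrix(rnorm(n * c_dim), n, c_dim)
obs_to_species <- c(1, 1, 2, 3, 3)                  # observation -> species index

input <- cbind(x, coords[obs_to_species, ], covs)   # species coords broadcast to rows
W1 <- matrix(rnorm(ncol(input) * h_d), ncol(input), h_d)
W2 <- matrix(rnorm(h_d * h_d), h_d, h_d)
h  <- relu(relu(input %*% W1) %*% W2)               # n x h_d hidden states

# scatter_mean: sum rows per species, divide by the number of observations
h_species <- rowsum(h, obs_to_species) / as.vector(table(obs_to_species))
h_back <- h_species[obs_to_species, ]               # broadcast back to n x h_d
dim(h_species); dim(h_back)                         # 3 x 4, then 5 x 4
```

rowsum() followed by division by the per-species counts is a plain-R scatter_mean; indexing by obs_to_species broadcasts the pooled states back to observation level.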

3.2 Graph Transformer Block (Phase 9 + B2)

The default path stacks \(L\) pre-norm transformer encoder blocks (n_gnn_layers, default 2). Each block has:

  • Multi-head attention (n_heads default 4) over the \(m\) species, with a per-head learnable phylogenetic bias added to the attention scores. With \(\mathbf{D}^2\) the squared cophenetic-distance matrix and \(\beta_h = \mathrm{softplus}(\mathrm{log\_bw}_h)\) a learned bandwidth per head, the bias is \(\mathbf{B}_h = -\mathbf{D}^2 / (2\beta_h^2)\). One head can attend tightly (fast-evolving traits), another broadly (conserved traits).
  • A position-wise FFN with width \(h_d \cdot \mathtt{ffn\_mult}\) (default 4), output linear initialised to zero so the block ≈ identity at training step 0 — preserves the gate-closed-at-init safety.
  • Layer norm before each sub-block, residual skips after each.
  • Optional per-layer covariate injection (when user covariates are present): cov_h = cov_encoder(user_covs) is added inside the block via a learnable projection, so covariate features are visible at every depth, not just the encoder.
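A single attention head with the phylogenetic bias, sketched in base R. The projections are random; layer norm, the FFN and the residual skip are omitted; and phylo_attention_head() is hypothetical. The point is only to show how the learned bandwidth \(\beta_h\) moves a head between distance-dominated and content-dominated attention.

```r
# One attention head over m species with bias B_h = -D^2 / (2 * beta_h^2),
# beta_h = softplus(log_bw_h). Hypothetical helper; random weights.
softplus <- function(z) log1p(exp(z))

softmax_rows <- function(S) {
  S <- S - apply(S, 1, max)   # subtract each row's max for numerical stability
  E <- exp(S)
  E / rowSums(E)
}

phylo_attention_head <- function(h, D, log_bw, Wq, Wk, Wv) {
  beta_h <- softplus(log_bw)
  content <- (h %*% Wq) %*% t(h %*% Wk) / sqrt(ncol(Wq))  # m x m dot-product scores
  bias <- -D^2 / (2 * beta_h^2)                           # phylogenetic bias B_h
  A <- softmax_rows(content + bias)
  list(out = A %*% (h %*% Wv), weights = A)
}

set.seed(1)
m <- 4; h_d <- 6; d_head <- 3
h  <- matrix(rnorm(m * h_d), m, h_d)
Wq <- matrix(rnorm(h_d * d_head, sd = 1 / sqrt(h_d)), h_d, d_head)
Wk <- matrix(rnorm(h_d * d_head, sd = 1 / sqrt(h_d)), h_d, d_head)
Wv <- matrix(rnorm(h_d * d_head, sd = 1 / sqrt(h_d)), h_d, d_head)
D <- matrix(c(0, 2, 4, 4,
              2, 0, 4, 4,
              4, 4, 0, 3,
              4, 4, 3, 0), nrow = 4)   # cophenetic distances (illustrative)

tight <- phylo_attention_head(h, D, log_bw = -2, Wq, Wk, Wv)  # narrow bandwidth
broad <- phylo_attention_head(h, D, log_bw = 3, Wq, Wk, Wv)   # wide bandwidth
round(diag(tight$weights), 3)   # ~1: each species attends almost only to itself
round(broad$weights, 2)         # weights now depend mainly on content, not distance
```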

Legacy single-head attention is retained behind use_transformer_blocks = FALSE for reconstructing pre-v0.9.0 fits.

3.3 Decoder and the gate

A symmetric two-layer MLP maps the species-broadcast hidden state back to the latent space: delta = dec2(ReLU(dec1(h))) of shape \(n \times p\).

The per-column blend gate is

\[ r = \sigma(\rho) \cdot \mathtt{gate\_cap}, \qquad r \in (0, \mathtt{gate\_cap}), \]

with \(\rho \in \mathbb{R}^p\) a learnable per-column parameter and \(\sigma(\cdot)\) the logistic sigmoid. The model output is

\[ \hat{x} = (1 - r) \cdot \mu + r \cdot \mathrm{delta} + \mathrm{cov\_linear}(u), \]

where \(u\) denotes the user covariates and cov_linear is a small direct linear regression on them, added outside the blend. It gives the GNN a “linear shortcut” on covariates so it does not have to learn linear covariate effects through nonlinear layers.

The gate is initialised so the GNN contribution starts small: \(\rho_\text{init} = -1\) for continuous columns (\(\sigma(-1) \approx 0.269\), so the effective gate is about \(0.269 \times \mathtt{gate\_cap}\)), and \(r_\text{init} \approx 0\) for discrete columns (effectively closed).
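A numeric check of the gate parameterisation, assuming gate_cap = 0.8 for illustration; gate() and blend() are hypothetical helpers, and the values of \(\mu\) and delta are made up.

```r
# r = sigmoid(rho) * gate_cap, and the blend (1 - r) * mu + r * delta.
gate  <- function(rho, gate_cap = 0.8) plogis(rho) * gate_cap
blend <- function(mu, delta, r) (1 - r) * mu + r * delta

plogis(-1)   # 0.2689414: sigmoid(-1)
gate(-1)     # 0.2151531 with gate_cap = 0.8
gate(-10)    # about 3.6e-05: effectively closed

mu <- 1.2; delta <- 1.8
# the blend equals mu + r * (delta - mu): an additive correction to the baseline
all.equal(blend(mu, delta, gate(-1)), mu + gate(-1) * (delta - mu))
```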

3.4 Loss and three safety regularisations

For each training batch:

\[ \mathcal{L} = \underbrace{\mathcal{L}_\text{type}(\hat{x}, y_\text{true})}_\text{type-aware reconstruction} + \lambda_\text{shrink} \cdot \mathrm{MSE}(\mathrm{delta} - \mu) + \lambda_\text{gate} \cdot \mathrm{MSE}(r). \]

  • \(\mathcal{L}_\text{type}\) dispatches by trait type: MSE for continuous / count / ordinal / proportion, BCE for binary and the zero-inflation gate (zi-gate), cross-entropy for categorical, and MSE on centred log-ratio (CLR) coordinates for multi-proportion.
  • \(\lambda_\text{shrink}\) (default 0.03) penalises delta drifting away from the baseline.
  • \(\lambda_\text{gate}\) (default 0.01) actively pushes the gate toward zero. Without this term, \(\rho\) receives no gradient when delta ≈ μ: since \(\hat{x} = \mu + r(\mathrm{delta} - \mu) + \mathrm{cov\_linear}(u)\), the reconstruction loss no longer depends on \(r\), and the shrinkage term never depends on \(\rho\). The gate would then stay at its initial value. The explicit penalty makes the gate drift toward baseline-only whenever the GNN adds nothing.

Together with the architectural cap (gate_cap ≤ 0.8 by default), these three give a strong inductive bias toward the phylogenetic baseline — useful when the GNN has nothing extra to learn on a given dataset.
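The need for the gate penalty can be checked numerically. The sketch below assumes one continuous column, MSE reconstruction, the default \(\lambda\) values from the text and gate_cap = 0.8; loss_fn() and fd_grad() are hypothetical. It sets delta equal to \(\mu\) and takes a central finite-difference derivative with respect to \(\rho\).

```r
# Gradient of the training loss w.r.t. rho when the GNN reproduces the baseline.
loss_fn <- function(rho, y, mu, delta, gate_cap = 0.8,
                    lambda_shrink = 0.03, lambda_gate = 0.01) {
  r <- plogis(rho) * gate_cap
  x_hat <- (1 - r) * mu + r * delta                 # gated blend
  recon <- mean((x_hat - y)^2)                      # reconstruction loss, continuous case
  shrink <- lambda_shrink * mean((delta - mu)^2)    # does not involve rho
  gate_pen <- lambda_gate * r^2                     # MSE(r) for a single column
  c(total = recon + shrink + gate_pen, recon = recon)
}
fd_grad <- function(f, x, eps = 1e-6) (f(x + eps) - f(x - eps)) / (2 * eps)

set.seed(7)
y <- rnorm(20)
mu <- y + rnorm(20, sd = 0.5)
delta <- mu                                         # GNN exactly reproduces the baseline
g_recon <- fd_grad(function(rho) loss_fn(rho, y, mu, delta)[["recon"]], -1)
g_total <- fd_grad(function(rho) loss_fn(rho, y, mu, delta)[["total"]], -1)
c(g_recon, g_total)   # ~0 (rounding only), then about 6.8e-04
```

Only the \(\lambda_\text{gate}\) term produces a positive gradient, so gradient descent lowers \(\rho\) and closes the gate.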

3.5 Within-row cross-trait attention (B3, v0.9.3, opt-in)

The encoder above mixes all trait columns into a single hidden vector via a linear projection, which loses per-trait identity. When use_trait_attention = TRUE, the model additionally builds a per-trait token sequence:

\[ \text{tokens}_{i,j} = \mathbf{W}_v \, x_{i,j} + \mathbf{e}_j, \qquad j = 1, \ldots, p, \]

with \(\mathbf{e}_j \in \mathbb{R}^{e_d}\) a learnable positional embedding per trait column (dim trait_embed_dim, default 32). One multi-head self-attention block (n_trait_heads, default 2) mixes these \(p\) tokens within each row \(i\), followed by mean-pool to a single \(e_d\)-dim feature. That feature is concatenated alongside \((x, \text{coords}, \text{covs})\) at the encoder input.
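A base-R sketch of the within-row trait attention, assuming a single head (the default above is two), random weights, and no layer norm or residual connection; trait_context() and all weight names are hypothetical.

```r
# tokens_ij = W_v * x_ij + e_j, one self-attention pass over the p tokens of a
# row, then mean-pool to an e_d-dim feature. Single head; random weights.
set.seed(3)
p <- 4; e_d <- 8
W_v   <- rnorm(e_d)                                 # maps a scalar cell value to e_d dims
E     <- matrix(rnorm(p * e_d), p, e_d)             # per-trait positional embeddings e_j
Wq    <- matrix(rnorm(e_d^2, sd = 1 / sqrt(e_d)), e_d, e_d)
Wk    <- matrix(rnorm(e_d^2, sd = 1 / sqrt(e_d)), e_d, e_d)
W_val <- matrix(rnorm(e_d^2, sd = 1 / sqrt(e_d)), e_d, e_d)

trait_context <- function(x_row) {
  tokens <- outer(x_row, W_v) + E                   # p x e_d; row j = W_v * x_ij + e_j
  scores <- (tokens %*% Wq) %*% t(tokens %*% Wk) / sqrt(e_d)
  A <- exp(scores - apply(scores, 1, max))          # row-wise softmax over the p tokens
  A <- A / rowSums(A)
  mixed <- A %*% (tokens %*% W_val)                 # each trait token attends to the others
  colMeans(mixed)                                   # mean-pool over traits -> length e_d
}

X <- matrix(rnorm(3 * p), 3, p)                     # 3 rows (observations) x p traits
ctx <- t(apply(X, 1, trait_context))                # 3 x e_d features for the encoder input
dim(ctx)
```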

The mechanism is intended for trait sets with strong within-row functional coupling that the joint MVN baseline cannot already capture (non-Gaussian, nonlinear, or non-monotone cross-trait structure). On the BIEN plant bench it did not improve pooled RMSE, most likely because the joint MVN/threshold-joint path already encodes \(\mathbf{\Sigma}\) for the dominant trait types, making a second cross-trait mechanism redundant.

Independent verification (#106, code in #116): a 60-replicate ablation at N=2000 (5 seeds per cell, with a multi-trait DGP that triggers the joint MVN baseline) confirmed that pigauto_OFF beats both pigauto_ON and pigauto_ON_L0 (the “lazy-optimiser disarmed” arm, with baseline_mu masked and lambda_shrink = 0): z-RMSE 1.038 (OFF) vs 1.056 (ON) vs 1.057 (ON_L0), with conformal coverage stable at 0.884–0.887. Even with the lazy-optimiser trap disarmed, the trait-attention arms score slightly worse than pigauto_OFF rather than finding new structure. See the NEWS entry for the full table.

The flag ships as opt-in because it may help on datasets where the linear \(\mathbf{\Sigma}\) assumption is too weak (e.g. functional traits with phase-transition or interaction structure), or in small-N regimes where the joint MVN baseline cannot yet estimate the cross-trait covariance reliably.

Backward-compatible: saved fits without the field default to use_trait_attention = FALSE and reconstruct identically.

4. Calibration: making the gate a real safety floor

After training, pigauto does not ship the gate \(r = \sigma(\rho) \cdot \mathtt{gate\_cap}\) that the optimiser converged to. Instead it replaces it with a per-trait calibrated gate \(r_\text{cal}\) chosen on held-out validation cells.

calibrate_gates() runs a per-trait grid search of \(r \in [0, \mathtt{gate\_cap}]\) minimising val-set reconstruction loss (MSE for continuous; 0/1 loss for discrete with an absolute cell floor and a split-validation cross-check that prevents the GNN from harming baseline accuracy). The output is a single scalar \(r_\text{cal}\) per latent column, stored on the fit and used at prediction time.

This is the second layer of safety: even if training pushes the learnable gate wide open, calibration on held-out cells can close it again. In practice, on datasets with strong phylogenetic signal where the GNN cannot improve on BM, \(r_\text{cal} \approx 0\) and the prediction is essentially the baseline.
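The continuous-trait case of the grid search can be sketched as follows. It assumes MSE loss and a 21-point grid on \([0, \mathtt{gate\_cap}]\) with gate_cap = 0.8, and it omits the discrete-trait safeguards; calibrate_gate_sketch() is hypothetical, and the validation data are simulated.

```r
# Per-trait gate calibration by grid search on held-out validation cells.
calibrate_gate_sketch <- function(y_val, mu_val, delta_val,
                                  gate_cap = 0.8, n_grid = 21) {
  grid <- seq(0, gate_cap, length.out = n_grid)   # includes r = 0 (baseline only)
  val_mse <- vapply(grid, function(r) {
    mean(((1 - r) * mu_val + r * delta_val - y_val)^2)
  }, numeric(1))
  grid[which.min(val_mse)]
}

set.seed(11)
y_val  <- rnorm(200)
mu_val <- y_val + rnorm(200, sd = 0.3)            # a decent baseline
r_bad  <- calibrate_gate_sketch(y_val, mu_val, mu_val + 5)                    # biased GNN
r_good <- calibrate_gate_sketch(y_val, mu_val, y_val + rnorm(200, sd = 0.2))  # better GNN
c(r_bad, r_good)   # 0, then roughly 0.7
```

Because the grid contains \(r = 0\), a systematically biased GNN is switched off entirely, while a GNN whose errors are smaller than the baseline's receives a large weight.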

5. Uncertainty quantification

pigauto exposes three distinct uncertainty mechanisms; they answer different questions and must not be conflated.

| Mechanism | Source | Validity |
|---|---|---|
| pred$se (continuous, count, ordinal, proportion) | Conditional-MVN SD from the BM baseline, delta-method back-transformed. | Exact under BM; model-dependent otherwise. |
| pred$se (binary, categorical) | Uncertainty score: \(\min(p, 1-p)\) for binary, \(1 - \max_k p_k\) for categorical. Not a Gaussian SE. | Use for ranking and reporting; do not plug into Rubin’s rules. |
| pred$conformal_lower, pred$conformal_upper | Split-conformal residual quantile on the val set: \(s = q_{(1-\alpha)}\bigl( \lvert y - \hat{y} \rvert_{\text{val}} \bigr)\). | Distribution-free 95% marginal coverage without parametric model assumptions, provided validation and missing cells are exchangeable. |
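A split-conformal sketch on simulated data with heavy-tailed noise. It uses the common finite-sample rank \(\lceil (n+1)(1-\alpha) \rceil\) for the quantile, which may differ in detail from pigauto's implementation; conformal_halfwidth() and the sine-curve predictor are hypothetical stand-ins.

```r
# Split-conformal half-width from absolute validation residuals.
conformal_halfwidth <- function(y_val, yhat_val, alpha = 0.05) {
  scores <- sort(abs(y_val - yhat_val))   # nonconformity scores on val cells
  n <- length(scores)
  k <- ceiling((n + 1) * (1 - alpha))     # finite-sample corrected rank
  if (k > n) return(Inf)                  # too few val cells for this alpha
  scores[k]
}

set.seed(5)
predictor <- function(x) sin(x)           # stand-in for y_hat
x_val <- runif(500, 0, 6)
y_val <- sin(x_val) + 0.3 * rt(500, df = 3)     # heavy-tailed, non-Gaussian noise
s <- conformal_halfwidth(y_val, predictor(x_val))
x_new <- runif(5000, 0, 6)
y_new <- sin(x_new) + 0.3 * rt(5000, df = 3)
cover <- mean(abs(y_new - predictor(x_new)) <= s)
cover   # close to 0.95 despite the non-Gaussian noise
```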

For multiple-imputation workflows, multi_impute() exposes two draw methods. "conformal" (default) samples missing cells from \(\mathcal{N}\bigl(\mu, (s/1.96)^2\bigr)\) on the transformed scale, i.e. a normal whose central 95% interval matches the conformal half-width \(s\), so the draws are calibrated against actual residuals. "mc_dropout" runs \(M\) stochastic GNN passes in training mode (dropout active) on top of BM-draw inputs; this typically gives a wider spread than conformal but reflects model uncertainty when the gate is fully open.
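The 1.96 in the conformal draw method comes from matching a normal's central 95% interval to the conformal half-width \(s\). A short check, with a hypothetical draw_conformal() helper and made-up centres and half-widths (centred on \(\mu\) as described above):

```r
# Draw M imputations per missing cell from N(mu, (s / 1.96)^2).
draw_conformal <- function(mu, s, M = 5) {
  # one row per missing cell, one column per imputation
  matrix(rnorm(length(mu) * M, mean = mu, sd = s / 1.96), nrow = length(mu))
}

set.seed(9)
mu <- c(0.4, -1.2); s <- c(0.5, 0.8)
draws <- draw_conformal(mu, s, M = 10000)
cover_draws <- rowMeans(abs(draws - mu) <= s)
cover_draws   # both close to 0.95
```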

For pooled inference, pool_mi() applies Rubin’s (1987) rules to \(M\) downstream fits and returns a tidy data.frame with estimate, std.error, df, fmi, riv per term.
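For reference, Rubin’s rules for a single coefficient are sketched below. This uses the classic large-sample degrees of freedom \((M-1)(1 + 1/\mathrm{riv})^2\) and the matching fraction of missing information; pool_mi() may use a small-sample adjustment instead, and rubin_pool() is hypothetical. The estimates and standard errors are made up.

```r
# Rubin's (1987) rules for one coefficient across M imputations.
rubin_pool <- function(est, se) {
  M <- length(est)
  q_bar <- mean(est)                   # pooled estimate
  u_bar <- mean(se^2)                  # within-imputation variance
  b <- var(est)                        # between-imputation variance
  t_var <- u_bar + (1 + 1 / M) * b     # total variance
  riv <- (1 + 1 / M) * b / u_bar       # relative increase in variance
  df <- (M - 1) * (1 + 1 / riv)^2      # classic large-sample df
  fmi <- (riv + 2 / (df + 3)) / (1 + riv)
  data.frame(estimate = q_bar, std.error = sqrt(t_var), df = df, fmi = fmi, riv = riv)
}

pooled <- rubin_pool(est = c(0.52, 0.48, 0.55, 0.50, 0.45),
                     se  = c(0.10, 0.11, 0.10, 0.12, 0.10))
pooled   # estimate 0.5, std.error sqrt(0.01304), riv about 0.154
```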

6. Tree uncertainty: a two-step workflow

When the species tree itself is uncertain, the Nakagawa & de Villemereuil (2019, Syst. Biol.) algorithm pools across \(T\) posterior trees. multi_impute_trees(traits, trees, m_per_tree = 5L) performs step 1: a full pigauto fit per tree, producing \(T \times M\) completed datasets (\(M\) = m_per_tree), each tagged with the index of the tree that produced it. Step 2 is the user’s responsibility: refit the downstream comparative model on the same tree that produced each dataset, then pool all \(T \times M\) fits via pool_mi():

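library(pigauto)

# Step 1: T x M completed datasets, each tagged with its source tree.
mi <- multi_impute_trees(traits, trees, m_per_tree = 5L)

# Step 2: refit the downstream model on the tree that produced each dataset.
# y and x are placeholder column names; rownames are assumed to be species.
fits <- Map(function(dat, t_idx) {
  dat$species <- rownames(dat)
  nlme::gls(y ~ x,
            correlation = ape::corBrownian(phy = trees[[t_idx]],
                                            form = ~species),
            data = dat, method = "ML")
}, mi$datasets, mi$tree_index)
pool_mi(fits)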

Compute is linear in \(T\). The 2019 paper’s relative-efficiency index typically converges before \(T = 50\).

7. Putting it together: a worked predictive equation

A clean end-to-end summary of what predict.pigauto_fit() returns for one continuous trait on one missing cell \(i\) in a single-tree, single-imputation call:

\[ \hat{y}_i \;=\; \underbrace{(1 - r_\text{cal}) \cdot \mu_i^{\text{BM/joint-MVN}}}_\text{baseline contribution} + \underbrace{r_\text{cal} \cdot \mathrm{delta}_i^{\text{GNN}}}_\text{GNN contribution} + \underbrace{\mathrm{cov\_linear}(u_i)}_\text{linear cov shortcut}, \]

with \(\mu_i\) from the joint MVN (or per-column BM fallback), \(\mathrm{delta}_i\) from the GNN (transformer blocks + optional within-row cross-trait attention), and \(r_\text{cal}\) from validation calibration. The conformal interval is

\[ \hat{y}_i \pm s_t, \]

where \(s_t\) is the split-conformal residual quantile for trait \(t\), giving \(\ge 95\%\) marginal coverage on the original scale, provided the validation cells are exchangeable with the missing cells.
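Plugging illustrative numbers (all made up, on the latent scale) into the predictive equation above shows how the pieces combine:

```r
# Worked instance of y_hat = (1 - r_cal) * mu + r_cal * delta + cov_linear(u).
mu_i    <- 1.20   # baseline mean (BM / joint MVN)
delta_i <- 1.80   # GNN prediction
r_cal   <- 0.25   # calibrated gate
cov_lin <- 0.10   # linear covariate shortcut, cov_linear(u_i)
s_t     <- 0.60   # conformal half-width for this trait

y_hat <- (1 - r_cal) * mu_i + r_cal * delta_i + cov_lin
c(lower = y_hat - s_t, estimate = y_hat, upper = y_hat + s_t)   # 0.85 1.45 2.05
```

Here the baseline contributes 0.90, the GNN 0.45 and the covariate shortcut 0.10.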

8. Where to look in the code

| Concept | File |
|---|---|
| BM kernel + conditional MVN | R/bm_internal.R |
| Joint MVN baseline | R/joint_mvn_baseline.R |
| Threshold-joint (binary + ordinal) | R/joint_threshold_baseline.R, R/liability.R |
| OVR categorical | R/ovr_categorical.R |
| Graph Transformer Block (B2) | R/graph_transformer_block.R |
| ResidualPhyloDAE + B3 trait attention | R/model_residual_dae.R |
| Training loop, gate calibration | R/fit_pigauto.R, R/fit_helpers.R |
| Prediction, conformal intervals | R/predict_pigauto.R |
| Multi-imputation pooling | R/multi_impute.R, R/pool_mi.R |
| Tree uncertainty workflow | R/multi_impute_trees.R |

References

  • Felsenstein, J. (1985). Phylogenies and the comparative method. The American Naturalist.
  • Pagel, M. (1999). Inferring the historical patterns of biological evolution. Nature.
  • Bruggeman, J., Heringa, J., & Brandt, B. W. (2009). PhyloPars: estimation of missing parameter values using phylogeny. Nucleic Acids Research.
  • Goolsby, E. W., Bruggeman, J., & Ané, C. (2017). Rphylopars: fast multivariate phylogenetic comparative methods for missing data and within-species variation. Methods in Ecology and Evolution.
  • Wright, S. (1934). An analysis of variability in number of digits in an inbred strain of guinea pigs. Genetics. (Liability model.)
  • Vaswani, A. et al. (2017). Attention is all you need. NeurIPS.
  • Ying, C. et al. (2021). Do Transformers Really Perform Bad for Graph Representation? NeurIPS. (Graphormer / multi-scale phylogenetic attention bias.)
  • Rubin, D. B. (1987). Multiple Imputation for Nonresponse in Surveys. Wiley.
  • Nakagawa, S., & Freckleton, R. P. (2008). Missing inaction: the dangers of ignoring missing data. Trends in Ecology & Evolution.
  • Nakagawa, S., & Freckleton, R. P. (2011). Model averaging, missing data and multiple imputation. Behavioral Ecology and Sociobiology.
  • Nakagawa, S., & de Villemereuil, P. (2019). A general method for simultaneously accounting for phylogenetic and species sampling uncertainty via Rubin’s rules in comparative analysis. Systematic Biology, 68(4), 632–641.
  • Vovk, V., Gammerman, A., & Shafer, G. (2005). Algorithmic Learning in a Random World. Springer. (Conformal prediction.)