This article opens up
pigauto’s engine room. It is intended for readers who want
to know exactly what the package does between
impute(traits, tree, covariates) and the returned
pigauto_result. We cover the data flow, the gated ensemble
formula, the phylogenetic baseline (joint MVN / threshold-joint / OVR),
the GNN’s transformer blocks, the optional within-row cross-trait
attention added in v0.9.3, and the calibration / uncertainty
quantification machinery on top.
We assume familiarity with maximum-likelihood Brownian motion (BM) on trees and with standard transformer attention. We do not assume familiarity with the package’s internals.
impute() chains six internal stages. Tensor shapes use
\(n\) for observations (= species in
single-obs mode, > species when each species has multiple
measurements), \(p\) for the
latent trait dimension (which expands categorical traits to
\(K\) one-hot columns), \(k\) for spectral Laplacian features, \(c\) for covariate dimensions, and \(T\) for the posterior tree count in
tree-uncertainty pooling.
    INPUT:
      traits          data.frame   n × p_traits  (mixed types, NAs allowed)
      tree            phylo        m tips
      covariates      data.frame   n × c         (optional)
    STEP 1: preprocess_traits()
      X_scaled        n × p   per-type encoding + z-score
      trait_map       list    descriptors per trait (type, levels, mean, sd)
      obs_to_species  n       only if n > m (multi-obs mode)
    STEP 2: build_phylo_graph()
      coords          m × k   Laplacian eigenvectors
      adj             m × m   Gaussian kernel on cophenetic distance
      D_sq            m × m   squared cophenetic distance (B2)
    STEP 3: fit_baseline()  [FIXED, NOT TRAINED]
      mu              m × p   phylogenetic baseline (BM / LP / joint)
      se              m × p   conditional-MVN standard errors
    STEP 4: fit_pigauto()  trains ResidualPhyloDAE on a DAE objective,
                           produces a torch nn_module + per-column gate
    STEP 5: calibrate_gates()  picks per-trait r_cal on held-out val cells
    STEP 6: predict.pigauto_fit()  applies the blend:
      pred = (1 - r_cal) · mu + r_cal · delta_GNN
The two essential ideas are:
The baseline is a per-trait-type dispatcher inside
fit_baseline(). For each trait you get \(\mu_i\) (the phylogenetic posterior mean for missing
cell \(i\), which acts as the prior for later stages) and \(\sigma_i\) (its posterior SD). All later
inference is conditional on this.
For one continuous trait \(y\) with \(y_O\) observed and \(y_M\) missing, \(y \sim \mathcal{N}(\beta \mathbf{1},\, \sigma^2 \mathbf{R})\) under BM on the species tree, where \(\mathbf{R} = \mathrm{cov2cor}(\mathrm{vcv}(\text{tree}))\) is the phylogenetic correlation matrix. The conditional posterior is the closed-form GLS solution:
\[ \hat{\mu}_M = \beta \mathbf{1} + \mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} (y_O - \beta \mathbf{1}_O), \quad \widehat{\sigma}^2_{M|O} = \sigma^2\bigl(1 - \mathrm{diag}(\mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} \mathbf{R}_{OM})\bigr). \]
R/bm_internal.R::bm_impute_col() implements this
directly. Count traits are log1p-transformed, ordinal traits are coerced
to integers and z-scored, proportions are logit-z-scored. The same
conditional formula is used in each transformed space.
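The conditional formula above can be sketched in a few lines of base R. This is an illustration only, not the code in R/bm_internal.R: the helper name `bm_conditional()`, the toy correlation matrix, and the plug-in estimates of \(\beta\) (GLS) and \(\sigma^2\) (ML on the observed species) are assumptions made for the example.

```r
# Hypothetical helper: conditional BM imputation for one transformed trait.
# R is a phylogenetic correlation matrix; y holds NA for missing species.
bm_conditional <- function(y, R) {
  obs <- !is.na(y)
  R_OO_inv <- solve(R[obs, obs, drop = FALSE])
  R_MO <- R[!obs, obs, drop = FALSE]

  # GLS estimate of the root mean: (1' R^-1 y) / (1' R^-1 1)
  beta <- sum(R_OO_inv %*% y[obs]) / sum(R_OO_inv)
  resid <- y[obs] - beta
  # ML plug-in estimate of the BM rate, using observed species only
  sigma2 <- drop(t(resid) %*% R_OO_inv %*% resid) / sum(obs)

  # Conditional mean and variance for the missing species
  mu_M <- beta + drop(R_MO %*% R_OO_inv %*% resid)
  var_M <- sigma2 * (1 - rowSums((R_MO %*% R_OO_inv) * R_MO))
  list(mu = mu_M, se = sqrt(var_M), beta = beta, sigma2 = sigma2)
}

# Toy example: two sister pairs, species 2 is missing.
R <- matrix(c(1.0, 0.8, 0.2, 0.2,
              0.8, 1.0, 0.2, 0.2,
              0.2, 0.2, 1.0, 0.6,
              0.2, 0.2, 0.6, 1.0), nrow = 4, byrow = TRUE)
y <- c(1.2, NA, -0.5, -0.1)
fit <- bm_conditional(y, R)
```

Because species 2 is strongly correlated with species 1, its imputed mean is pulled toward the value observed for species 1, and its conditional variance is smaller than the marginal \(\sigma^2\).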
Discrete traits use a softer phylogenetic prior: each species’s class probability is a kernel-weighted average over the rest of the tree, \(\Pr(y_i = k) \propto \sum_{j} \mathbf{A}_{ij} \, \mathbb{1}(y_j = k)\), with \(\mathbf{A}\) the same Gaussian kernel used in the GNN’s adjacency. For categorical traits with \(K\) classes this returns \(K\) log-probability columns; binary returns one logit column.
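A minimal base-R sketch of the kernel-weighted vote described above. The helper name `label_propagation()`, the `bandwidth` argument, and the toy distance matrix are assumptions for illustration; the package's own kernel settings may differ.

```r
# Hypothetical helper: kernel-weighted class probabilities for one categorical trait.
# D is a cophenetic distance matrix, y an integer class vector (NA = missing).
label_propagation <- function(y, D, K, bandwidth = 1) {
  A <- exp(-D^2 / (2 * bandwidth^2))   # Gaussian kernel on distance
  diag(A) <- 0                         # a species does not vote for itself
  obs <- !is.na(y)
  onehot <- matrix(0, length(y), K)
  onehot[cbind(which(obs), y[obs])] <- 1
  votes <- A %*% onehot                # sum_j A_ij * 1(y_j = k)
  votes / rowSums(votes)               # normalise each row to probabilities
}

D <- matrix(c(0, 1, 4, 4,
              1, 0, 4, 4,
              4, 4, 0, 1,
              4, 4, 1, 0), nrow = 4, byrow = TRUE)
y <- c(1, 1, 2, NA)                    # species 4 is unobserved
P <- label_propagation(y, D, K = 2)
```

Species 4 sits next to species 3 (class 2) and far from the class-1 pair, so nearly all of its probability mass lands on class 2. Taking `log(P)` would give log-probability columns of the kind the text mentions.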
When Rphylopars is installed and ≥ 2 continuous-like
latent columns exist, pigauto upgrades to a joint multivariate-BM
baseline. Stack the \(p\) BM-eligible
columns into \(\mathbf{Y}\). Under
joint BM, \(\mathrm{vec}(\mathbf{Y}) \sim
\mathcal{N}(\mathrm{vec}(\boldsymbol{\beta}),\,
\mathbf{\Sigma} \otimes \mathbf{R})\).
Rphylopars::phylopars() fits \(\hat{\mathbf{\Sigma}}\) jointly across
traits, and the conditional posterior is the GLS solution with the
Kronecker covariance. This captures cross-trait phylogenetic correlation
that the per-column path loses. The bench in
script/bench_joint_baseline.R showed a 33.7% RMSE improvement on
simulated correlated BM data.
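To see why the Kronecker covariance matters, the sketch below conditions on all observed cells at once. It assumes \(\mathbf{\Sigma}\) and \(\boldsymbol{\beta}\) are already known (in pigauto they are estimated by Rphylopars::phylopars()); the helper name `joint_conditional_mean()` and all numbers are hypothetical.

```r
# Hypothetical helper: conditional mean under vec(Y) ~ N(vec(beta), Sigma %x% R).
joint_conditional_mean <- function(Y, R, Sigma, beta) {
  n <- nrow(Y); p <- ncol(Y)
  V <- kronecker(Sigma, R)          # covariance of vec(Y), column-major order
  y <- as.vector(Y)
  m <- rep(beta, each = n)          # per-trait means, stacked like vec(Y)
  obs <- !is.na(y)
  y_hat <- y
  y_hat[!obs] <- m[!obs] +
    drop(V[!obs, obs, drop = FALSE] %*% solve(V[obs, obs], y[obs] - m[obs]))
  matrix(y_hat, n, p)
}

R <- matrix(c(1.0, 0.8, 0.2,
              0.8, 1.0, 0.2,
              0.2, 0.2, 1.0), nrow = 3, byrow = TRUE)
Y <- cbind(t1 = c(1.0, NA, -0.4), t2 = c(0.9, 1.1, -0.6))
beta <- c(0, 0)

Sigma_indep <- diag(2)                        # no cross-trait correlation
Sigma_corr  <- matrix(c(1, 0.9, 0.9, 1), 2)   # strongly correlated traits

fill_indep <- joint_conditional_mean(Y, R, Sigma_indep, beta)[2, 1]
fill_corr  <- joint_conditional_mean(Y, R, Sigma_corr,  beta)[2, 1]
```

With a diagonal \(\mathbf{\Sigma}\) the fill for the missing cell reduces to the per-column BM answer. With correlated traits, the observed value of trait 2 on the same species shifts the fill, which is the information the per-column path cannot use.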
To bring binary traits inside the joint MVN, pigauto’s
fit_joint_threshold_baseline() uses the Wright–Falconer
liability model: each observed binary cell \(y
\in \{0, 1\}\) is replaced by the posterior mean of an underlying
continuous liability \(L\)
truncated by \(y\). With prior \(L \sim \mathcal{N}(0, 1)\) this is the
standard truncated-normal mean. The resulting matrix is then handed to
phylopars() exactly like the continuous case. Binary
posteriors are decoded back to probabilities via \(p = \Phi\bigl(\mu_L / \sqrt{1 +
\sigma_L^2}\bigr)\), clipped to \([0.01, 0.99]\). Ordinal liability uses K-1
interval cuts (B3 ordinal).
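The two transformations in this paragraph are short enough to write out. The function names below are hypothetical; the formulas are the standard truncated-normal mean under \(L \sim \mathcal{N}(0, 1)\) and the probit decode with the clipping range stated above.

```r
# Posterior mean of a N(0, 1) liability truncated by a binary observation:
# E[L | L > 0] = dnorm(0) / (1 - pnorm(0)), and E[L | L < 0] is its negative.
liability_mean <- function(y) {
  pos <- dnorm(0) / pnorm(0, lower.tail = FALSE)
  ifelse(y == 1, pos, -pos)
}

# Probit decode of a liability posterior (mean mu_L, SD sd_L), clipped to [0.01, 0.99].
decode_probability <- function(mu_L, sd_L) {
  p <- pnorm(mu_L / sqrt(1 + sd_L^2))
  pmin(pmax(p, 0.01), 0.99)
}

encoded <- liability_mean(c(1, 0, 1, NA))   # NA cells stay NA for the joint fit to impute
decoded <- decode_probability(mu_L = c(-0.5, 0, 1.5), sd_L = c(0.3, 1.0, 0.8))
```

A larger posterior SD shrinks the decoded probability toward 0.5, which is how liability uncertainty carries through to the binary scale.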
The single-fit approach to categorical liability (K columns into one phylopars call) is rank-deficient and unstable. pigauto instead runs \(K\) independent threshold-joint fits — class \(k\) vs the rest — and renormalises the resulting \(K\) probabilities into a row-stochastic distribution. This is BACE’s strategy and lifts AVONET Trophic.Level accuracy from ~42% to ~72%.
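The renormalisation step is simple to illustrate. In this sketch the per-class probabilities are made-up numbers standing in for the output of \(K\) separate class-vs-rest fits, and the class labels are illustrative only.

```r
# Hypothetical input: one column per class, each from its own class-vs-rest fit.
ovr_probs <- matrix(c(0.70, 0.20, 0.30,
                      0.10, 0.60, 0.10,
                      0.05, 0.05, 0.80),
                    nrow = 3, byrow = TRUE,
                    dimnames = list(paste0("sp", 1:3),
                                    c("herbivore", "omnivore", "carnivore")))

# Independent fits need not sum to 1 across classes, so renormalise each row.
class_probs <- ovr_probs / rowSums(ovr_probs)

# Hard prediction: the class with the largest renormalised probability.
predicted <- colnames(class_probs)[max.col(class_probs, ties.method = "first")]
```

Renormalising does not change which class is largest within a row; it only makes the row a proper probability distribution.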
When each species has multiple observations, baselines run at species level. Phase 10 aggregates obs→species with type-aware rules: mean for continuous, modal class (or argmax one-hot) for discrete. B1 (v0.9.0) adds an opt-in soft path that preserves evidence strength: a species observed as class 1 in 6/10 rows uses the convex combination \(p \cdot E[L \mid L > 0] + (1 - p) \cdot E[L \mid L < 0]\) instead of collapsing to hard class 1.
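The soft aggregation rule for a binary trait can be checked numerically. The helper name `soft_liability()` is hypothetical; the formula is the convex combination quoted above with \(L \sim \mathcal{N}(0, 1)\).

```r
# Convex combination of the two truncated-normal means, weighted by the
# observed class-1 proportion p within a species.
soft_liability <- function(p) {
  pos <- dnorm(0) / pnorm(0, lower.tail = FALSE)   # E[L | L > 0]
  p * pos + (1 - p) * (-pos)
}

soft <- soft_liability(6 / 10)   # class 1 in 6 of 10 rows
hard <- soft_liability(1)        # what collapsing to hard class 1 would give
```

The soft value has the same sign as the hard one but only a fifth of its magnitude, so a 6/10 split enters the joint fit as weak evidence rather than a confident class-1 observation.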
The GNN’s job is to learn an additive correction on top of the phylogenetic baseline. It is a denoising autoencoder (DAE) trained with masked-cell reconstruction loss.
Name note. The torch class is
ResidualPhyloDAE because its internal blocks use ResNet-style residual skip connections. The network output delta is not a statistical residual \(y - \mu\); it is a full per-cell prediction, blended externally with \(\mu\) via the per-trait gate.
Input tensors:
- x \(\in \mathbb{R}^{n \times p}\): current trait latent matrix (missing cells replaced by a learnable mask token).
- coords \(\in \mathbb{R}^{m \times k}\): species-level spectral Laplacian features.
- covs \(\in \mathbb{R}^{n \times c}\): [baseline_mu | NA-mask | user_covs] (the user covariates plus a per-cell mask indicator and the baseline prediction).
If use_trait_attention = TRUE (new in v0.9.3, see §3.5),
a pooled trait-context feature of dimension trait_embed_dim
is also concatenated. The encoder is a two-layer MLP:
\[ h = \mathrm{ReLU}\bigl(\mathbf{W}_2 \, \mathrm{Dropout}(\mathrm{ReLU}(\mathbf{W}_1 \mathrm{concat}(x, \text{coords}, \text{covs}, \ldots)))\bigr), \]
producing \(h \in \mathbb{R}^{n \times h_d}\) with hidden dim \(h_d\) (default 64).
In multi-obs mode, \(h\) is averaged
across observations of the same species (scatter_mean) to
produce a species-level hidden state \(h_\text{species} \in \mathbb{R}^{m \times
h_d}\) for the graph message passing, then broadcast back to
observation level afterwards.
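The pool-then-broadcast step can be mimicked in base R. This is a sketch of the idea only: the real model does this on torch tensors, and the toy sizes and the assumption that every species has at least one observation are mine.

```r
set.seed(1)
obs_to_species <- c(1L, 1L, 2L, 3L, 3L, 3L)   # 6 observations, 3 species
h <- matrix(rnorm(6 * 4), nrow = 6)            # observation-level hidden state, h_d = 4

# Pool: mean of the rows belonging to each species (a scatter_mean).
sums <- rowsum(h, group = obs_to_species)      # one row per species, ordered by index
counts <- tabulate(obs_to_species, nbins = nrow(sums))
h_species <- sums / counts                     # m x h_d

# ... graph message passing would act on h_species here ...

# Broadcast back to observation level by indexing with the same map.
h_back <- h_species[obs_to_species, , drop = FALSE]
```

After the broadcast, all observations of one species share the same species-level vector, so any within-species variation has to come from the other inputs.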
The default path stacks \(L\)
pre-norm transformer encoder blocks (n_gnn_layers, default
2). Each block has:
- Multi-head self-attention (n_heads, default 4) over the \(m\) species, with a per-head learnable phylogenetic bias added to the attention scores. With \(\mathbf{D}^2\) the squared cophenetic-distance matrix and \(\beta_h = \mathrm{softplus}(\mathrm{log\_bw}_h)\) a learned bandwidth per head, the bias is \(\mathbf{B}_h = -\mathbf{D}^2 / (2\beta_h^2)\). One head can attend tightly (fast-evolving traits), another broadly (conserved traits).
- cov_h = cov_encoder(user_covs) is added inside the block via a learnable projection, so covariate features are visible at every depth, not just the encoder.
Legacy single-head attention is retained behind
use_transformer_blocks = FALSE for reconstructing
pre-v0.9.0 fits.
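The effect of the per-head bandwidth is easiest to see with the content scores switched off. The sketch below is base R rather than torch, covers a single head, and sets the query-key scores to zero so that only the phylogenetic bias shapes the attention weights; the function name and toy distances are assumptions.

```r
softplus <- function(x) log1p(exp(x))

# Attention weights for one head with bias B_h = -D^2 / (2 * beta_h^2).
# `scores` stands in for the usual scaled query-key products (zero here).
phylo_attention <- function(D_sq, log_bw, scores = 0) {
  beta_h <- softplus(log_bw)
  logits <- scores - D_sq / (2 * beta_h^2)
  logits <- logits - apply(logits, 1, max)   # row-wise stabilisation
  w <- exp(logits)
  w / rowSums(w)                             # softmax over species
}

D <- matrix(c(0, 1, 3,
              1, 0, 3,
              3, 3, 0), nrow = 3, byrow = TRUE)
tight <- phylo_attention(D^2, log_bw = -1)   # small bandwidth
broad <- phylo_attention(D^2, log_bw =  3)   # large bandwidth
```

The tight head puts almost all of a species' attention on itself and its nearest relatives, while the broad head spreads weight across the whole tree.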
A symmetric two-layer MLP maps the species-broadcast hidden state
back to the latent space: delta = dec2(ReLU(dec1(h))) of
shape \(n \times p\).
The per-column blend gate is
\[ r = \sigma(\rho) \cdot \mathtt{gate\_cap}, \qquad r \in (0, \mathtt{gate\_cap}), \]
with \(\rho \in \mathbb{R}^p\) a learnable per-column parameter. The model output is
\[ \hat{x} = (1 - r) \cdot \mu + r \cdot \mathrm{delta} + \mathrm{cov\_linear}(u), \]
where cov_linear is a small direct linear regression on
user covariates (added outside the blend; gives the GNN a “linear
shortcut” on covariates so it doesn’t have to learn \(\beta\) through nonlinear layers).
The gate is initialised so the GNN contribution starts small: \(\rho_\text{init} = -1\) for continuous columns (effective gate \(\sigma(-1) \approx 0.269\) × gate_cap), while discrete columns start with an effective gate \(r_\text{init} \approx 0\) (fully closed).
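A quick numeric check of the gate and the blend, using base R's `plogis()` for the sigmoid. The values of `mu` and `delta` are made up, and `blend()` is a hypothetical helper that mirrors the model-output formula above.

```r
gate_cap <- 0.8
rho <- -1
r <- plogis(rho) * gate_cap   # sigmoid(-1) is about 0.269, so r is about 0.215

# Gated blend with an optional additive covariate term outside the blend.
blend <- function(mu, delta, r, cov_term = 0) (1 - r) * mu + r * delta + cov_term

mu <- 0.40      # baseline prediction for one cell
delta <- 1.00   # GNN per-cell prediction for the same cell
x_hat <- blend(mu, delta, r)
```

Because \(\sigma(\rho) < 1\) for any finite \(\rho\), the gate can approach but never reach gate_cap, so the baseline always keeps a weight of at least 1 - gate_cap.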
For each training batch:
\[ \mathcal{L} = \underbrace{\mathcal{L}_\text{type}(\hat{x}, y_\text{true})}_\text{type-aware reconstruction} + \lambda_\text{shrink} \cdot \mathrm{MSE}(\mathrm{delta} - \mu) + \lambda_\text{gate} \cdot \mathrm{MSE}(r). \]
The shrinkage term penalises delta drifting away from the baseline, and the gate term penalises open gates. Together with the architectural cap (gate_cap ≤ 0.8 by
default), these terms give a strong inductive bias toward the
phylogenetic baseline, which is useful when the GNN has nothing extra to
learn on a given dataset.
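The structure of the loss can be sketched as follows. Several things here are assumptions: plain MSE stands in for the type-aware reconstruction loss, \(\mathrm{MSE}(r)\) is read as the mean squared gate value, and the \(\lambda\) defaults are placeholders rather than the package's settings.

```r
# Hypothetical sketch of the three-term training loss for continuous columns.
pigauto_loss_sketch <- function(x_hat, y_true, delta, mu, r,
                                lambda_shrink = 0.1, lambda_gate = 0.01) {
  recon  <- mean((x_hat - y_true)^2)   # stands in for the type-aware loss
  shrink <- mean((delta - mu)^2)       # pulls delta toward the baseline
  gate   <- mean(r^2)                  # pulls the gate toward zero
  recon + lambda_shrink * shrink + lambda_gate * gate
}

mu     <- c(0.2, -0.1, 0.5)
delta  <- c(0.4,  0.0, 0.1)
r      <- 0.2
y_true <- c(0.3, -0.2, 0.4)
x_hat  <- (1 - r) * mu + r * delta
loss   <- pigauto_loss_sketch(x_hat, y_true, delta, mu, r)
```

Both penalty terms vanish only when delta equals the baseline and the gate is closed, so the network has to earn any departure from \(\mu\) through a lower reconstruction error.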
The encoder above mixes all trait columns into a single hidden vector
via a linear projection, which loses per-trait identity. When
use_trait_attention = TRUE, the model additionally builds a
per-trait token sequence:
\[ \text{tokens}_{i,j} = \mathbf{W}_v \, x_{i,j} + \mathbf{e}_j, \qquad j = 1, \ldots, p, \]
with \(\mathbf{e}_j \in
\mathbb{R}^{e_d}\) a learnable positional embedding per trait
column (dim trait_embed_dim, default 32). One multi-head
self-attention block (n_trait_heads, default 2) mixes these
\(p\) tokens within each row \(i\), followed by mean-pool to a single
\(e_d\)-dim feature. That feature is
concatenated alongside \((x, \text{coords},
\text{covs})\) at the encoder input.
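A stripped-down version of the within-row mechanism, written in base R for one row. It is a sketch under strong simplifications: a single head, no separate query/key/value projections, and random stand-ins for the learnable \(\mathbf{W}_v\) and \(\mathbf{e}_j\); the helper name `trait_context()` is hypothetical.

```r
# One row's trait-context feature: tokenise, self-attend across traits, mean-pool.
trait_context <- function(x_row, w_v, E) {
  tokens <- outer(x_row, w_v) + E                  # p x e_d: W_v * x_ij + e_j
  scores <- tokens %*% t(tokens) / sqrt(ncol(E))   # p x p scaled dot products
  scores <- scores - apply(scores, 1, max)         # row-wise stabilisation
  attn <- exp(scores)
  attn <- attn / rowSums(attn)                     # softmax over the p traits
  colMeans(attn %*% tokens)                        # mean-pool to one e_d-dim vector
}

set.seed(42)
p <- 3; e_d <- 4
x_row <- c(0.5, -1.2, 0.3)            # one observation's latent trait values
w_v <- rnorm(e_d)                     # stand-in for the value projection
E <- matrix(rnorm(p * e_d), p, e_d)   # stand-in for per-trait embeddings
ctx <- trait_context(x_row, w_v, E)
```

The per-trait embeddings are what let attention tell columns apart; without them, two traits with the same value would produce identical tokens.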
The mechanism is intended for trait sets with strong within-row functional coupling that the joint MVN baseline cannot already capture (non-Gaussian, nonlinear, or non-monotone cross-trait structure). On the BIEN plant bench it did not improve pooled RMSE — the joint MVN/threshold-joint path already encodes \(\mathbf{\Sigma}\) for the dominant trait types, so the second cross-trait mechanism is redundant.
Independent verification (#106, code in
#116):
a 60-replicate ablation at N=2000 (5 seeds per cell, multi-trait DGP
that fires the joint MVN baseline) confirmed pigauto_OFF
beats both pigauto_ON and pigauto_ON_L0 (the
“lazy-optimiser disarmed” arm with baseline_mu masked and
lambda_shrink = 0) — z-RMSE 1.038 vs 1.056 vs 1.057;
conformal coverage stable at 0.884–0.887. Even with the lazy-optimiser
trap disarmed, the network regresses against the baseline rather than
finding new structure. See the NEWS entry for the full table.
The flag ships as opt-in because it may help on datasets where the linear \(\mathbf{\Sigma}\) assumption is too weak (e.g. functional traits with phase-transition or interaction structure), or in small-N regimes where the joint MVN baseline hasn’t yet converged on the cross-trait covariance.
Backward-compatible: saved fits without the field default to
use_trait_attention = FALSE and reconstruct
identically.
After training, pigauto does not ship the gate \(\sigma(\rho)\) that the optimiser converged to. Instead it overrides \(\rho\) with a per-trait calibrated gate \(r_\text{cal}\) chosen on held-out validation cells.
calibrate_gates() runs a per-trait grid search of \(r \in [0, \mathtt{gate\_cap}]\) minimising
val-set reconstruction loss (MSE for continuous; 0/1 loss for discrete
with an absolute cell floor and a split-validation cross-check that
prevents the GNN from harming baseline accuracy). The output is a single
scalar \(r_\text{cal}\) per latent
column, stored on the fit and used at prediction time.
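For a continuous column, the core of the grid search looks roughly like this. The sketch omits the cell floor and split-validation cross-check mentioned above; the helper name and the grid resolution `n_grid` are assumptions.

```r
# Hypothetical sketch: choose the gate that minimises validation MSE for one column.
calibrate_gate_sketch <- function(mu, delta, y_val, gate_cap = 0.8, n_grid = 41) {
  grid <- seq(0, gate_cap, length.out = n_grid)
  val_mse <- vapply(grid,
                    function(r) mean(((1 - r) * mu + r * delta - y_val)^2),
                    numeric(1))
  grid[which.min(val_mse)]   # ties resolve to the smallest r, i.e. toward the baseline
}

mu    <- c(0, 1, 2)   # baseline predictions on held-out cells
delta <- c(1, 2, 3)   # GNN predictions on the same cells

r_baseline_best <- calibrate_gate_sketch(mu, delta, y_val = mu)
r_gnn_best      <- calibrate_gate_sketch(mu, delta, y_val = delta)
r_mixed         <- calibrate_gate_sketch(mu, delta, y_val = 0.5 * mu + 0.5 * delta)
```

When the baseline already matches the held-out cells the search returns 0, and when the GNN matches them it stops at gate_cap rather than 1, which is the architectural cap at work.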
This is the second layer of safety: even if training somehow pushes the learnable gate wide open, calibration on held-out cells can close it back down. In practice, on datasets with strong phylogenetic signal where the GNN cannot improve on BM, \(r_\text{cal} \approx 0\) and the prediction is essentially the baseline.
pigauto exposes three distinct uncertainty mechanisms; they answer different questions and must not be conflated.
| Mechanism | Source | Validity |
|---|---|---|
| pred$se (cont./count/ordinal/prop.) | Conditional-MVN SD from the BM baseline, delta-method back-transformed. | Exact under BM, model-dependent. |
| pred$se (binary/categorical) | \(\min(p, 1{-}p)\) / \(1 - \max_k p_k\): an uncertainty score, not a Gaussian SE. | Use for ranking/reporting; do not plug into Rubin’s rules. |
| pred$conformal_lower, pred$conformal_upper | Split conformal residual quantile on the val set: \(s = q_{(1-\alpha)}(\lvert y - \hat{y} \rvert_{\text{val}})\). | Distribution-free 95% marginal coverage regardless of model assumptions. |
For multiple-imputation workflows, multi_impute()
exposes two draw methods. "conformal" (default) samples
missing cells from a normal distribution with mean \(\mu\) and standard deviation
\(s/1.96\) on the transformed scale, so the draws are calibrated against actual
residuals. "mc_dropout" runs \(M\) stochastic GNN passes in training mode
(dropout active) on top of BM-draw inputs; its draws are wider than the conformal
ones but reflect model uncertainty when the gate is fully open.
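The split-conformal score and the "conformal" draw can be sketched with simulated validation residuals. The data are synthetic, and the finite-sample rank \(\lceil (n+1)(1-\alpha) \rceil\) is the textbook choice; whether pigauto uses that rank or the plain \((1-\alpha)\) quantile is not stated here, so treat it as an assumption.

```r
set.seed(7)
y_val    <- rnorm(200)
yhat_val <- y_val + rnorm(200, sd = 0.5)   # simulated validation predictions
alpha <- 0.05

abs_res <- sort(abs(y_val - yhat_val))
n_val <- length(abs_res)
# Finite-sample split-conformal rank; close to the plain quantile for large n.
k <- min(n_val, ceiling((n_val + 1) * (1 - alpha)))
s <- abs_res[k]

# Interval for a new prediction on the same (transformed) scale.
mu_new <- 0.3
interval <- c(lower = mu_new - s, upper = mu_new + s)

# "conformal"-style imputation draws: normal with SD s / 1.96.
draws <- rnorm(5, mean = mu_new, sd = s / 1.96)
```

Dividing by 1.96 turns the 95% half-width into a standard deviation, so about 95% of the draws fall inside the conformal interval.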
For pooled inference, pool_mi() applies Rubin’s (1987)
rules to \(M\) downstream fits and
returns a tidy data.frame with estimate,
std.error, df, fmi,
riv per term.
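For reference, Rubin's rules for a single coefficient fit in a few lines. This sketch uses the classical large-sample degrees of freedom and the usual fmi expression; pool_mi() may use a small-sample correction, and `rubin_pool()` and its inputs are hypothetical.

```r
# Hypothetical helper: pool M point estimates and standard errors for one term.
rubin_pool <- function(est, se) {
  M <- length(est)
  q_bar <- mean(est)              # pooled estimate
  W <- mean(se^2)                 # within-imputation variance
  B <- var(est)                   # between-imputation variance
  T_var <- W + (1 + 1 / M) * B    # total variance
  riv <- (1 + 1 / M) * B / W      # relative increase in variance
  df <- (M - 1) * (1 + 1 / riv)^2
  fmi <- (riv + 2 / (df + 3)) / (riv + 1)   # fraction of missing information
  data.frame(estimate = q_bar, std.error = sqrt(T_var),
             df = df, fmi = fmi, riv = riv)
}

pooled <- rubin_pool(est = c(0.52, 0.47, 0.55, 0.50, 0.49),
                     se  = c(0.10, 0.11, 0.10, 0.12, 0.10))
```

The pooled standard error is always at least as large as the average within-imputation one, because the between-imputation spread is added on top.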
When the species tree itself is uncertain, the Nakagawa & de
Villemereuil (2019, Syst. Biol.) algorithm pools across \(T\) posterior trees.
multi_impute_trees(traits, trees, m_per_tree = 5L) performs
step 1 — a full pigauto fit per tree, producing \(T \times M\) completed datasets, each
tagged with the tree index that produced it. Step 2 is
the user’s responsibility: refit the downstream comparative model using
the same tree that produced each dataset, then pool all
\(T \times M\) fits via
pool_mi():

    # Refit the comparative model on each completed dataset,
    # using the same tree that produced that dataset.
    fits <- Map(function(dat, t_idx) {
      dat$species <- rownames(dat)
      nlme::gls(y ~ x,
                correlation = ape::corBrownian(phy = trees[[t_idx]],
                                               form = ~species),
                data = dat, method = "ML")
    }, mi$datasets, mi$tree_index)
    pool_mi(fits)

Compute is linear in \(T\). The 2019 paper’s relative-efficiency index typically converges before \(T = 50\).
A clean end-to-end summary of what predict.pigauto_fit()
returns for one continuous trait on one missing cell \(i\) in a single-tree, single-imputation
call:
\[ \hat{y}_i \;=\; \underbrace{(1 - r_\text{cal}) \cdot \mu_i^{\text{BM/joint-MVN}}}_\text{baseline contribution} + \underbrace{r_\text{cal} \cdot \mathrm{delta}_i^{\text{GNN}}}_\text{GNN contribution} + \underbrace{\mathrm{cov\_linear}(u_i)}_\text{linear cov shortcut}, \]
with \(\mu_i\) from the joint MVN (or per-column BM fallback), \(\mathrm{delta}_i\) from the GNN (transformer blocks + optional within-row cross-trait attention), and \(r_\text{cal}\) from validation calibration. The conformal interval is
\[ \hat{y}_i \pm s_t, \]
where \(s_t\) is the split-conformal residual quantile for trait \(t\), giving \(\ge 95\%\) marginal coverage on the original scale.
| Concept | File |
|---|---|
| BM kernel + conditional MVN | R/bm_internal.R |
| Joint MVN baseline | R/joint_mvn_baseline.R |
| Threshold-joint (binary + ordinal) | R/joint_threshold_baseline.R, R/liability.R |
| OVR categorical | R/ovr_categorical.R |
| Graph Transformer Block (B2) | R/graph_transformer_block.R |
| ResidualPhyloDAE + B3 trait attention | R/model_residual_dae.R |
| Training loop, gate calibration | R/fit_pigauto.R, R/fit_helpers.R |
| Prediction, conformal intervals | R/predict_pigauto.R |
| Multi-imputation pooling | R/multi_impute.R, R/pool_mi.R |
| Tree uncertainty workflow | R/multi_impute_trees.R |