---
title: "GNN architecture and the math behind pigauto"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{GNN architecture and the math behind pigauto}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = FALSE
)
```

This article opens up `pigauto`'s engine room. It is intended for readers
who want to know *exactly* what the package does between
`impute(traits, tree, covariates)` and the returned `pigauto_result`. We
cover the data flow, the gated ensemble formula, the phylogenetic
baseline (joint MVN / threshold-joint / OVR), the GNN's transformer
blocks, the optional within-row cross-trait attention added in v0.9.3,
and the calibration and uncertainty-quantification machinery on top.

We assume familiarity with maximum-likelihood Brownian motion (BM) on
trees and with standard transformer attention. We do **not** assume
familiarity with the package's internals.

## 1. End-to-end data flow

`impute()` chains six internal stages. Tensor shapes use $m$ for species
(tree tips), $n$ for observations ($n = m$ in single-obs mode, $n > m$
when species have multiple measurements), $p$ for the *latent* trait
dimension (which expands each categorical trait into $K$ one-hot
columns), $k$ for spectral Laplacian features, $c$ for covariate
dimensions, and $T$ for the number of posterior trees in
tree-uncertainty pooling.
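To make the latent width $p$ concrete, the hypothetical helper below counts it for a small mixed data.frame under the encodings described in §2: one column each for continuous, count, proportion, ordinal, and binary traits, and $K$ one-hot columns for a categorical trait with $K$ classes. It is a sketch, not a pigauto function; it assumes a two-level factor is treated as binary and ignores zero-inflated and multi-proportion traits, whose latent layout this article does not spell out.

```r
# Hypothetical helper, NOT part of pigauto: count the latent columns p
# implied by a trait data.frame, under the simplified encodings of §2.
# Zero-inflated and multi-proportion traits are not covered.
latent_width <- function(traits) {
  per_col <- vapply(traits, function(col) {
    if (is.ordered(col)) return(1L)        # ordinal: one z-scored integer column
    if (is.factor(col)) {
      k <- nlevels(col)
      return(if (k == 2L) 1L else k)       # binary: one logit; categorical: K one-hot
    }
    if (is.logical(col) || is.numeric(col)) return(1L)  # binary or continuous-like
    stop("column type not covered by this sketch")
  }, integer(1))
  sum(per_col)
}

traits <- data.frame(
  mass   = c(12.1, NA, 30.4),                               # continuous
  diet   = factor(c("herb", "carn", NA),
                  levels = c("herb", "carn", "omni")),      # 3-class categorical
  flight = factor(c("yes", "no", "yes")),                   # binary
  size   = factor(c("S", "L", "M"), levels = c("S", "M", "L"),
                  ordered = TRUE)                           # ordinal
)
latent_width(traits)   # 1 + 3 + 1 + 1 = 6
```

Only the categorical column widens the latent matrix; every other type maps to a single transformed column.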
```
INPUT:
  traits          data.frame   n × p_traits   (mixed types, NAs allowed)
  tree            phylo        m tips
  covariates      data.frame   n × c          (optional)

STEP 1: preprocess_traits()
  X_scaled        n × p    per-type encoding + z-score
  trait_map       list     descriptors per trait (type, levels, mean, sd)
  obs_to_species  n        only if n > m (multi-obs mode)

STEP 2: build_phylo_graph()
  coords          m × k    Laplacian eigenvectors
  adj             m × m    Gaussian kernel on cophenetic distance
  D_sq            m × m    squared cophenetic distance (B2)

STEP 3: fit_baseline()   [FIXED, NOT TRAINED]
  mu              m × p    phylogenetic baseline (BM / LP / joint)
  se              m × p    conditional-MVN standard errors

STEP 4: fit_pigauto()
  trains ResidualPhyloDAE on a DAE objective
  produces a torch nn_module + per-column gate

STEP 5: calibrate_gates()
  picks per-trait r_cal on held-out val cells

STEP 6: predict.pigauto_fit()
  applies the blend: pred = (1 - r_cal) · mu + r_cal · delta_GNN
```

The two essential ideas are:

1. **A simple, well-understood phylogenetic baseline does most of the
   work**, and the GNN contributes only where calibration allows it to.
2. **The blend gate $r_\text{cal}$ is a safety floor**: at
   $r_\text{cal} = 0$ the prediction collapses to the baseline. Because
   $r_\text{cal} = 0$ is always a candidate during calibration, the
   worst case is bounded at roughly baseline performance (up to noise in
   the validation split), regardless of how poorly the GNN trains.

## 2. The phylogenetic baseline

The baseline is a per-trait-type dispatcher inside `fit_baseline()`. For
each trait it returns $\mu_i$, the baseline's posterior mean for missing
cell $i$, and $\sigma_i$, its posterior SD. All later stages are
conditional on these.

### 2.1 Continuous, count, ordinal, proportion: Brownian motion

For one continuous trait $y$ with $y_O$ observed and $y_M$ missing,
$y \sim \mathcal{N}(\beta \mathbf{1},\, \sigma^2 \mathbf{R})$ under BM on
the species tree, where
$\mathbf{R} = \mathrm{cov2cor}(\mathrm{vcv}(\text{tree}))$ is the
phylogenetic correlation matrix.
The conditional posterior is the closed-form GLS solution:

$$
\hat{\mu}_M = \beta \mathbf{1}_M + \mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} (y_O - \beta \mathbf{1}_O),
\qquad
\widehat{\sigma}^2_{M|O} = \sigma^2\bigl(1 - \mathrm{diag}(\mathbf{R}_{MO} \mathbf{R}_{OO}^{-1} \mathbf{R}_{OM})\bigr).
$$

The leading $1$ in the variance is $\mathrm{diag}(\mathbf{R}_{MM})$,
which equals one because $\mathbf{R}$ is a correlation matrix.
`R/bm_internal.R::bm_impute_col()` implements this directly. Count traits
are log1p-transformed, ordinal traits are coerced to integers and
z-scored, and proportions are logit-transformed and z-scored. The same
conditional formula is then applied in each transformed space.

### 2.2 Binary and categorical: phylogenetic label propagation (LP)

Discrete traits use a softer phylogenetic prior: each species' class
probability is a kernel-weighted average over the rest of the tree,
$\Pr(y_i = k) \propto \sum_{j \in O} \mathbf{A}_{ij} \, \mathbb{1}(y_j = k)$,
where $O$ is the set of species with the trait observed and $\mathbf{A}$
is the same Gaussian kernel used for the GNN's adjacency. For categorical
traits with $K$ classes this returns $K$ log-probability columns; binary
traits return one logit column.

### 2.3 Joint MVN baseline (Phase 2)

When `Rphylopars` is installed and at least two continuous-like latent
columns exist, pigauto upgrades to a joint multivariate-BM baseline.
Stack the $p$ BM-eligible columns into the $m \times p$ matrix
$\mathbf{Y}$. Under joint BM,
$\mathrm{vec}(\mathbf{Y}) \sim \mathcal{N}\bigl(\mathrm{vec}(\mathbf{1}\boldsymbol{\beta}^\top),\, \mathbf{\Sigma} \otimes \mathbf{R}\bigr)$,
where $\boldsymbol{\beta}$ holds the per-trait root means and
$\mathbf{\Sigma}$ is the $p \times p$ evolutionary covariance.
`Rphylopars::phylopars()` fits $\hat{\mathbf{\Sigma}}$ jointly across
traits, and the conditional posterior is the §2.1 formula with
$\sigma^2 \mathbf{R}$ replaced by the Kronecker covariance
$\hat{\mathbf{\Sigma}} \otimes \mathbf{R}$. This captures cross-trait
phylogenetic correlation that the per-column path loses: an observed
value of one trait now informs the missing values of correlated traits in
the same species. The bench in `script/bench_joint_baseline.R` showed a
33.7% RMSE improvement on simulated correlated BM data.
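To make the conditioning step concrete, here is a minimal base-R sketch of the conditional mean and SD of the missing cells under $\mathrm{vec}(\mathbf{Y}) \sim \mathcal{N}(\mathrm{vec}(\mathbf{1}\boldsymbol{\beta}^\top), \mathbf{\Sigma} \otimes \mathbf{R})$, treating $\boldsymbol{\beta}$, $\mathbf{\Sigma}$, and $\mathbf{R}$ as known. It is not pigauto's or Rphylopars' code (both of which also estimate the parameters), and the name `joint_bm_condition()` is hypothetical. With $p = 1$ and $\mathbf{\Sigma} = \sigma^2$ it reduces to the single-trait formula of §2.1.

```r
# Minimal sketch, NOT pigauto's implementation: conditional mean and SD of
# the missing cells of an m x p trait matrix Y under joint BM,
#   vec(Y) ~ N(vec(1 beta'), Sigma %x% R),
# with beta (length p), Sigma (p x p) and R (m x m) treated as known.
joint_bm_condition <- function(Y, R, Sigma, beta) {
  m <- nrow(Y)
  p <- ncol(Y)
  C   <- kronecker(Sigma, R)        # Cov(vec(Y)); vec() stacks columns
  y   <- as.vector(Y)               # column-major order matches kronecker()
  mu0 <- rep(beta, each = m)        # prior mean of vec(Y)
  M <- which(is.na(y))
  O <- which(!is.na(y))
  C_MO <- C[M, O, drop = FALSE]
  K    <- C_MO %*% solve(C[O, O])   # regression of missing on observed cells
  cond_mean <- mu0[M] + as.vector(K %*% (y[O] - mu0[O]))
  cond_var  <- diag(C)[M] - rowSums(K * C_MO)  # diag(C_MM - K C_OM)
  mu <- matrix(y, m, p)
  se <- matrix(0, m, p)             # observed cells carry no imputation SE
  mu[M] <- cond_mean
  se[M] <- sqrt(pmax(cond_var, 0))  # guard against tiny negative round-off
  list(mu = mu, se = se)
}

# Four species, tree-like correlation ((A,B),(C,D)); species 2 lacks trait 1
R <- matrix(c(1, .8, .3, .3,
              .8, 1, .3, .3,
              .3, .3, 1, .7,
              .3, .3, .7, 1), 4, 4)
Y <- cbind(c(1.2, NA, -0.5, 0.3), c(0.4, 2.0, NA, -0.1))
Sigma <- matrix(c(2, 2.2, 2.2, 3), 2)   # strongly correlated traits
joint_bm_condition(Y, R, Sigma, beta = c(0.1, -0.2))$mu
```

With a diagonal $\mathbf{\Sigma}$ the Kronecker covariance is block-diagonal and the result matches column-by-column BM; the off-diagonal entries are what let species 2's large observed trait-2 value pull up its trait-1 prediction. A dense `solve()` on the observed block costs $O(|O|^3)$ with $|O|$ up to $mp$, so this only suits toy examples; fast tree-based algorithms such as those in Rphylopars avoid forming $\mathbf{\Sigma} \otimes \mathbf{R}$ explicitly.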
### 2.4 Threshold-joint baseline (Phase 3)

To bring binary traits inside the joint MVN,
`fit_joint_threshold_baseline()` uses the Wright–Falconer liability
model: each observed binary cell $y \in \{0, 1\}$ is replaced by the
posterior mean of an underlying continuous *liability* $L$, truncated to
$L > 0$ when $y = 1$ and to $L < 0$ when $y = 0$. With prior
$L \sim \mathcal{N}(0, 1)$ this is the standard truncated-normal mean,
$E[L \mid L > 0] = \phi(0) / (1 - \Phi(0)) = \sqrt{2/\pi} \approx 0.798$
(and $-0.798$ for $L < 0$). The resulting matrix is then handed to
`phylopars()` exactly as in the continuous case. Binary posteriors are
decoded back to probabilities via
$p = \Phi\bigl(\mu_L / \sqrt{1 + \sigma_L^2}\bigr)$, clipped to
$[0.01, 0.99]$. Ordinal liability uses $K - 1$ interval cuts (B3
ordinal).

### 2.5 OVR categorical (Phase 6)

The single-fit approach to categorical liability ($K$ columns in one
`phylopars()` call) is rank-deficient and unstable: the $K$ one-hot
columns always sum to one, so they are linearly dependent. pigauto
instead runs $K$ independent threshold-joint fits, one per class ($k$
versus the rest), and renormalises the resulting $K$ probabilities into a
row-stochastic distribution. This is BACE's strategy, and it lifts AVONET
`Trophic.Level` accuracy from about 42% to about 72%.

### 2.6 Multi-obs aggregation (Phase 10 + B1 soft)

When each species has multiple observations, the baselines run at
species level. Phase 10 aggregates observations to species with
type-aware rules: the mean for continuous traits and the modal class (or
argmax of the one-hot encoding) for discrete traits. B1 (v0.9.0) adds an
opt-in soft path that preserves evidence strength: a species observed as
class 1 in 6 of 10 rows ($p = 0.6$) uses the convex combination
$p \cdot E[L \mid L > 0] + (1 - p) \cdot E[L \mid L < 0]$ instead of
collapsing to a hard class 1. Assuming the $\mathcal{N}(0, 1)$ prior of
§2.4, that is $0.6 \times 0.798 - 0.4 \times 0.798 \approx 0.160$ rather
than $0.798$.

## 3. The GNN: ResidualPhyloDAE

The GNN's job is to learn an *additive correction* on top of the
phylogenetic baseline: the blend in §3.3 can be rewritten as
$\mu + r \cdot (\mathrm{delta} - \mu)$. It is a denoising autoencoder
(DAE) trained with a masked-cell reconstruction loss.

> **Name note.** The torch class is `ResidualPhyloDAE` because its
> internal blocks use ResNet-style residual skip connections.
> The network output `delta` is **not** a statistical residual
> $y - \mu$; it is a full per-cell prediction, blended externally with
> $\mu$ via the per-trait gate.

### 3.1 Encoder

Input tensors:

* `x` $\in \mathbb{R}^{n \times p}$: the current latent trait matrix
  (missing cells replaced by a learnable mask token).
* `coords` $\in \mathbb{R}^{m \times k}$: species-level spectral
  Laplacian features.
* `covs`, with $n$ rows: the concatenation
  `[baseline_mu | NA-mask | user_covs]`, i.e. the baseline prediction, a
  per-cell missingness indicator, and the $c$ user covariates.

If `use_trait_attention = TRUE` (new in v0.9.3, see §3.5), a pooled
trait-context feature of dimension `trait_embed_dim` is also
concatenated.

The encoder is a two-layer MLP:

$$
h = \mathrm{ReLU}\bigl(\mathbf{W}_2 \, \mathrm{Dropout}(\mathrm{ReLU}(\mathbf{W}_1 \, \mathrm{concat}(x, \text{coords}, \text{covs}, \ldots)))\bigr),
$$

producing $h \in \mathbb{R}^{n \times h_d}$ with hidden dimension $h_d$
(default 64). In multi-obs mode, $h$ is averaged across observations of
the same species (`scatter_mean`) to give a species-level hidden state
$h_\text{species} \in \mathbb{R}^{m \times h_d}$ for graph message
passing, which is broadcast back to observation level afterwards.

### 3.2 Graph Transformer Block (Phase 9 + B2)

The default path stacks $L$ pre-norm transformer encoder blocks
(`n_gnn_layers`, default 2). Each block has:

* Multi-head attention (`n_heads`, default 4) over the $m$ species, with
  a **per-head learnable phylogenetic bias** added to the attention
  scores. With $\mathbf{D}^2$ the squared cophenetic-distance matrix and
  $\beta_h = \mathrm{softplus}(\mathrm{log\_bw}_h)$ a learned bandwidth
  per head, the bias is $\mathbf{B}_h = -\mathbf{D}^2 / (2\beta_h^2)$.
  One head can therefore attend tightly (suited to fast-evolving traits)
  while another attends broadly (conserved traits).
* A position-wise FFN of width $h_d \cdot \mathtt{ffn\_mult}$ (default 4)
  whose output linear layer is initialised to zero, so the FFN sub-block
  starts as an identity map through its residual skip. This preserves the
  gate-closed-at-init safety.
* Layer norm before each sub-block, and a residual skip around each.
* Optional per-layer covariate injection (when user covariates are
  present): `cov_h = cov_encoder(user_covs)` is added inside the block via
  a learnable projection, so covariate features are visible at every
  depth, not just at the encoder.

Legacy single-head attention is retained behind
`use_transformer_blocks = FALSE` for reconstructing pre-v0.9.0 fits.

### 3.3 Decoder and the gate

A symmetric two-layer MLP maps the species-broadcast hidden state back to
the latent space: `delta = dec2(ReLU(dec1(h)))`, of shape $n \times p$.

The per-column **blend gate** is

$$
r = \sigma(\rho) \cdot \mathtt{gate\_cap}, \qquad r \in (0, \mathtt{gate\_cap}),
$$

with $\rho \in \mathbb{R}^p$ a learnable per-column parameter. The model
output is

$$
\hat{x} = (1 - r) \cdot \mu + r \cdot \mathrm{delta} + \mathrm{cov\_linear}(u),
$$

where $u$ are the user covariates and `cov_linear` is a small direct
linear regression on them. It sits outside the blend and gives the GNN a
"linear shortcut" for covariate effects, so they need not be learned
through the nonlinear layers. The gate is initialised so that the GNN
contribution starts small: $\rho_\text{init} = -1$ for continuous columns
(effective gate $\sigma(-1) \approx 0.27$ times `gate_cap`), while
discrete columns start with the gate effectively closed ($r \approx 0$).

### 3.4 Loss and three safety regularisations

For each training batch:

$$
\mathcal{L} = \underbrace{\mathcal{L}_\text{type}(\hat{x}, y_\text{true})}_\text{type-aware reconstruction}
+ \lambda_\text{shrink} \cdot \mathrm{MSE}(\mathrm{delta}, \mu)
+ \lambda_\text{gate} \cdot \mathrm{MSE}(r, \mathbf{0}).
$$

* $\mathcal{L}_\text{type}$ dispatches by trait type: MSE for continuous,
  count, ordinal, and proportion traits; BCE for binary traits and
  zero-inflation gate ("zi-gate") columns; cross-entropy for categorical
  traits; and MSE on centred log-ratio (CLR) coordinates for
  multi-proportion traits.
* $\lambda_\text{shrink}$ (default 0.03) penalises `delta` for drifting
  away from the baseline.
* $\lambda_\text{gate}$ (default 0.01) actively pushes the gate toward
  zero. Without this term, $\rho$ gets essentially no gradient once
  delta ≈ μ: the blend $(1 - r)\mu + r \cdot \mathrm{delta}$ then barely
  depends on $r$, and the shrinkage term does not involve $\rho$ at all,
  so the gate would stay at its initial value. The explicit penalty makes
  the gate default toward baseline-only.

Together with the architectural cap (`gate_cap`, ≤ 0.8 by default), these
three give a *strong inductive bias toward the phylogenetic baseline*,
which is useful when the GNN has nothing extra to learn on a given
dataset.

### 3.5 Within-row cross-trait attention (B3, v0.9.3, opt-in)

The encoder above mixes all trait columns into a single hidden vector via
a linear projection, which loses per-trait identity. When
`use_trait_attention = TRUE`, the model additionally builds a per-trait
token sequence:

$$
\text{tokens}_{i,j} = \mathbf{W}_v \, x_{i,j} + \mathbf{e}_j, \qquad j = 1, \ldots, p,
$$

where $\mathbf{W}_v \in \mathbb{R}^{e_d}$ lifts each scalar cell to a
token and $\mathbf{e}_j \in \mathbb{R}^{e_d}$ is a learnable positional
embedding per trait column (dimension `trait_embed_dim`, default 32). One
multi-head self-attention block (`n_trait_heads`, default 2) mixes these
$p$ tokens within each row $i$, followed by a mean-pool to a single
$e_d$-dimensional feature. That feature is concatenated alongside
$(x, \text{coords}, \text{covs})$ at the encoder input.

The mechanism is intended for trait sets with strong within-row
functional coupling that the joint MVN baseline cannot already capture
(non-Gaussian, nonlinear, or non-monotone cross-trait structure). **On the
BIEN plant bench it did not improve pooled RMSE**: the joint
MVN/threshold-joint path already encodes $\mathbf{\Sigma}$ for the
dominant trait types, so the second cross-trait mechanism is redundant
there.
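The token construction and pooling can be illustrated for one row with a deliberately simplified, untrained base-R version: a single attention head with random weights and no layer norm, residual, or output projection. The projections `P_q`, `P_k`, `P_v` and both function names are hypothetical; pigauto's torch module is multi-head and trained end to end.

```r
# Simplified sketch of §3.5 for one row, NOT pigauto's module: single head,
# random weights, no layer norm or residual. All names are hypothetical.
softmax_rows <- function(s) {
  s <- s - apply(s, 1, max)   # subtract each row's max for numerical stability
  e <- exp(s)
  e / rowSums(e)              # each row now sums to 1
}

trait_context <- function(x_row, W_v, E, P_q, P_k, P_v) {
  # x_row: length-p latent values for one observation
  # W_v:   length-e_d vector lifting each scalar cell to a token
  # E:     p x e_d matrix of per-trait positional embeddings e_j
  # P_q, P_k, P_v: e_d x e_d query/key/value projections
  tokens <- outer(x_row, W_v) + E                  # row j = W_v * x_j + e_j
  q <- tokens %*% P_q
  k <- tokens %*% P_k
  v <- tokens %*% P_v
  attn <- softmax_rows(q %*% t(k) / sqrt(ncol(k))) # p x p trait-to-trait weights
  colMeans(attn %*% v)                             # mean-pool p tokens -> length e_d
}

set.seed(1)
p <- 5; e_d <- 8
E   <- matrix(rnorm(p * e_d), p, e_d)
W_v <- rnorm(e_d)
P   <- replicate(3, matrix(rnorm(e_d^2) / sqrt(e_d), e_d, e_d), simplify = FALSE)
ctx <- trait_context(rnorm(p), W_v, E, P[[1]], P[[2]], P[[3]])
length(ctx)   # 8: one e_d-dimensional context feature per row
```

Mean-pooling gives a fixed-size feature whatever the number of traits. Self-attention followed by a mean-pool is insensitive to token order, so it is the positional embeddings `E` that keep trait identity in the pooled feature.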
**Independent verification ([#106](https://github.com/itchyshin/pigauto/issues/106), code in [#116](https://github.com/itchyshin/pigauto/pull/116)):**
a 60-replicate ablation at N = 2000 (5 seeds per cell, with a multi-trait
data-generating process that triggers the joint MVN baseline) confirmed
that `pigauto_OFF` beats both `pigauto_ON` and `pigauto_ON_L0` (the
"lazy-optimiser disarmed" arm, with `baseline_mu` masked and
`lambda_shrink = 0`): z-RMSE was 1.038, 1.056, and 1.057, respectively,
and conformal coverage was stable at 0.884–0.887. Even with the
lazy-optimiser trap disarmed, the network does worse than the baseline
rather than finding new structure. See the NEWS entry for the full table.

The flag ships as opt-in because it may help on datasets where the linear
$\mathbf{\Sigma}$ assumption is too weak (e.g. functional traits with
phase-transition or interaction structure), or in small-N regimes where
the joint MVN baseline cannot yet estimate the cross-trait covariance
reliably. It is backward-compatible: saved fits without the field default
to `use_trait_attention = FALSE` and reconstruct identically.

## 4. Calibration: making the gate a real safety floor

After training, pigauto does **not** use the learned gate
$r = \sigma(\rho) \cdot \mathtt{gate\_cap}$ that the optimiser converged
to. Instead it overrides it with a per-trait *calibrated* gate
$r_\text{cal}$ chosen on held-out validation cells.

`calibrate_gates()` runs a per-trait grid search over
$r \in [0, \mathtt{gate\_cap}]$, minimising validation-set reconstruction
loss (MSE for continuous traits; 0/1 loss for discrete traits, with an
absolute cell floor and a split-validation cross-check that guards
against the GNN harming baseline accuracy). The output is a single scalar
$r_\text{cal}$ per latent column, stored on the fit and used at
prediction time.

This is the second layer of safety: even if training pushes the learnable
gate wide open, calibration on held-out cells can close it back down.
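For a single continuous latent column, the grid search reduces to something like the minimal sketch below. It is a hypothetical simplification of `calibrate_gates()`: it omits the covariate term, the discrete-trait 0/1 loss, the cell floor, and the split-validation check, and the grid resolution and `gate_cap = 0.8` are illustrative choices.

```r
# Hypothetical simplification of per-trait gate calibration (continuous
# column only; NOT pigauto's calibrate_gates()). Grid-search the blend
# weight r on held-out validation cells, keeping r = 0 as a candidate.
calibrate_gate_sketch <- function(y_val, mu_val, delta_val,
                                  gate_cap = 0.8, n_grid = 41L) {
  grid <- seq(0, gate_cap, length.out = n_grid)   # grid[1] is exactly 0
  val_mse <- vapply(grid, function(r) {
    pred <- (1 - r) * mu_val + r * delta_val       # the §3.3 blend
    mean((y_val - pred)^2)
  }, numeric(1))
  grid[which.min(val_mse)]   # first minimum wins, so ties favour smaller r
}

set.seed(2)
y  <- rnorm(50)
mu <- y + rnorm(50, sd = 0.5)                          # a noisy baseline
calibrate_gate_sketch(y, mu, delta_val = y)            # 0.8: GNN matches truth
calibrate_gate_sketch(y, y, delta_val = rnorm(50))     # 0: baseline already perfect
```

Because the candidate set always contains $r = 0$ and ties go to the smaller value, a GNN that does not beat the baseline on validation cells is switched off rather than blended in.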
In practice, on datasets with strong phylogenetic signal where the GNN
cannot improve on BM, calibration selects $r_\text{cal} = 0$ (or a value
very close to it), so the prediction is essentially the baseline.

## 5. Uncertainty quantification

pigauto exposes three distinct uncertainty mechanisms; they answer
different questions and must not be conflated.

| Mechanism | Source | Validity |
|-----------|--------|----------|
| `pred$se` (continuous, count, ordinal, proportion) | Conditional-MVN SD from the BM baseline, back-transformed with the delta method. | Exact under BM; model-dependent. |
| `pred$se` (binary, categorical) | $\min(p, 1 - p)$ for binary and $1 - \max_k p_k$ for categorical: an uncertainty score, **not** a Gaussian SE. | Use for ranking and reporting; do **not** plug into Rubin's rules. |
| `pred$conformal_lower`, `pred$conformal_upper` | Split-conformal residual quantile on the validation set: $s = q_{1-\alpha}\bigl(\lvert y - \hat{y} \rvert_{\text{val}}\bigr)$. | Distribution-free 95% marginal coverage, assuming validation and missing cells are exchangeable; no other model assumptions. |

For multiple-imputation workflows, `multi_impute()` exposes two draw
methods. `"conformal"` (the default) samples missing cells from
$\mathcal{N}\bigl(\mu, (s / 1.96)^2\bigr)$ on the transformed scale. With
SD $s / 1.96$, 95% of the sampling distribution lies within $\pm s$,
matching the conformal interval, so the draws are calibrated against
actual residuals. `"mc_dropout"` runs $M$ stochastic GNN passes in
training mode (dropout active) on top of BM-draw inputs; this is wider
than conformal but reflects model uncertainty when the gate is fully
open.

For pooled inference, `pool_mi()` applies Rubin's (1987) rules to $M$
downstream fits and returns a tidy data.frame with `estimate`,
`std.error`, `df`, `fmi` (fraction of missing information), and `riv`
(relative increase in variance) per term.

## 6. Tree uncertainty: a two-step workflow

When the species tree itself is uncertain, the algorithm of Nakagawa &
de Villemereuil (2019, *Syst. Biol.*) pools across $T$ posterior trees.
`multi_impute_trees(traits, trees, m_per_tree = 5L)` performs **step 1**:
a full pigauto fit per tree, producing $T \times M$ completed datasets
($M$ = `m_per_tree`), each tagged with the index of the tree that
produced it.
**Step 2** is the user's responsibility: refit the downstream comparative
model using the **same tree** that produced each dataset, then pool all
$T \times M$ fits via `pool_mi()`:

```r
library(pigauto)
# Step 1: T x M completed datasets, each tagged with its source tree
mi <- multi_impute_trees(traits, trees, m_per_tree = 5L)
# Step 2: refit each dataset with the tree that produced it
# (y and x are placeholder columns; rownames must match tip labels)
fits <- Map(function(dat, t_idx) {
  dat$species <- rownames(dat)
  nlme::gls(y ~ x,
            correlation = ape::corBrownian(phy = trees[[t_idx]], form = ~species),
            data = dat, method = "ML")
}, mi$datasets, mi$tree_index)
pool_mi(fits)
```

Compute scales linearly with $T$. In the 2019 paper, the
relative-efficiency index typically converges before $T = 50$.

## 7. Putting it together: a worked predictive equation

A clean end-to-end summary of what `predict.pigauto_fit()` returns for
one continuous trait on one missing cell $i$ in a single-tree,
single-imputation call:

$$
\hat{y}_i \;=\; \underbrace{(1 - r_\text{cal}) \cdot \mu_i^{\text{BM/joint-MVN}}}_\text{baseline contribution}
+ \underbrace{r_\text{cal} \cdot \mathrm{delta}_i^{\text{GNN}}}_\text{GNN contribution}
+ \underbrace{\mathrm{cov\_linear}(u_i)}_\text{linear covariate shortcut},
$$

with $\mu_i$ from the joint MVN (or the per-column BM fallback),
$\mathrm{delta}_i$ from the GNN (transformer blocks plus optional
within-row cross-trait attention), $u_i$ the user covariates for cell
$i$, and $r_\text{cal}$ from validation calibration. The conformal
interval is

$$
\hat{y}_i \pm s_t,
$$

where $s_t$ is the split-conformal residual quantile for trait $t$,
giving $\ge 95\%$ marginal coverage on the original scale (assuming
validation and missing cells are exchangeable).

## 8. Where to look in the code

| Concept | File |
|---------|------|
| BM kernel + conditional MVN | `R/bm_internal.R` |
| Joint MVN baseline | `R/joint_mvn_baseline.R` |
| Threshold-joint (binary + ordinal) | `R/joint_threshold_baseline.R`, `R/liability.R` |
| OVR categorical | `R/ovr_categorical.R` |
| Graph Transformer Block (B2) | `R/graph_transformer_block.R` |
| ResidualPhyloDAE + B3 trait attention | `R/model_residual_dae.R` |
| Training loop, gate calibration | `R/fit_pigauto.R`, `R/fit_helpers.R` |
| Prediction, conformal intervals | `R/predict_pigauto.R` |
| Multi-imputation pooling | `R/multi_impute.R`, `R/pool_mi.R` |
| Tree uncertainty workflow | `R/multi_impute_trees.R` |

## References

- Felsenstein, J. (1985). *Phylogenies and the comparative method.* The American Naturalist.
- Pagel, M. (1999). *Inferring the historical patterns of biological evolution.* Nature.
- Bruggeman, J., Heringa, J., & Brandt, B. W. (2009). *PhyloPars: estimation of missing parameter values using phylogeny.* Nucleic Acids Research.
- Goolsby, E. W., Bruggeman, J., & Ané, C. (2017). *Rphylopars: fast multivariate phylogenetic comparative methods for missing data and within-species variation.* Methods in Ecology and Evolution.
- Wright, S. (1934). *An analysis of variability in number of digits in an inbred strain of guinea pigs.* Genetics. (Liability model.)
- Vaswani, A., et al. (2017). *Attention is all you need.* NeurIPS.
- Ying, C., et al. (2021). *Do Transformers Really Perform Bad for Graph Representation?* NeurIPS. (Graphormer / multi-scale phylogenetic attention bias.)
- Rubin, D. B. (1987). *Multiple Imputation for Nonresponse in Surveys.* Wiley.
- Nakagawa, S., & Freckleton, R. P. (2008). *Missing inaction: the dangers of ignoring missing data.* Trends in Ecology & Evolution.
- Nakagawa, S., & Freckleton, R. P. (2011). *Model averaging, missing data and multiple imputation.* Behavioral Ecology and Sociobiology.
- Nakagawa, S., & de Villemereuil, P. (2019). *A general method for simultaneously accounting for phylogenetic and species sampling uncertainty via Rubin's rules in comparative analysis.* Systematic Biology 68(4): 632–641.
- Vovk, V., Gammerman, A., & Shafer, G. (2005). *Algorithmic Learning in a Random World.* Springer. (Conformal prediction.)