Before locking in the architecture for PredictLM-Mini, we ran six experiments at the 26–57M parameter scale. Five of them lost to a smaller model. This post is the full writeup, including the configurations that didn't work and the gradient-NaN trail that ate three days.
The premise
PredictLM-Base is a 26M-parameter transformer. The natural question before shipping Mini was: is there a better architecture in the 26–57M range that we should be distilling from instead? If we could find a teacher that beats Base by even 2 pp R² on regression, every downstream variant — including Mini — inherits that headroom.
So we ran six experiments. Each started from a fresh 26M-parameter base and applied a controlled change, with experiment 4 as the one exception that also scales the model to 57M. We held the data mixture, the optimizer schedule, the eval harness, and the random seed fixed.
The six experiments
1. Dual-teacher distillation
Hypothesis: Distill from Base and TabPFN v2 jointly. The model averages two priors and inherits the strengths of both.
Outcome: Worse than Base by 1.4 pp mean R². The two teachers disagree on small-data classification in ways that produce destructive interference in the student's loss landscape. The classification accuracy stays close to the better of the two, but regression suffers visibly.
Verdict: Lost. Single-teacher distillation wins.
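To make the setup concrete, here is a minimal sketch of one way to distill from two teachers jointly: train the student against a weighted mixture of the two teacher distributions. The post doesn't say how the teachers were actually combined, so the mixing weight `alpha`, the function names, and the toy numbers are all assumptions for illustration.

```python
import numpy as np

def softmax(logits, axis=-1):
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def dual_teacher_loss(student_logits, teacher_a_probs, teacher_b_probs, alpha=0.5):
    """Cross-entropy of the student against a mixture of two teacher distributions.

    `alpha` is a hypothetical mixing weight; the distributions can be over
    classes or over regression bins.
    """
    target = alpha * teacher_a_probs + (1.0 - alpha) * teacher_b_probs
    log_q = np.log(softmax(student_logits) + 1e-12)
    return float(-(target * log_q).sum(axis=-1).mean())

# Toy example: two teachers that disagree sharply on a single input.
teacher_a = np.array([[0.9, 0.1]])
teacher_b = np.array([[0.1, 0.9]])
student_logits = np.array([[2.0, -2.0]])  # the student sides with teacher A

loss_mixed = dual_teacher_loss(student_logits, teacher_a, teacher_b)
loss_single = dual_teacher_loss(student_logits, teacher_a, teacher_a)
print(f"mixed target: {loss_mixed:.3f}, single teacher: {loss_single:.3f}")
```

When the teachers disagree, the mixed target flattens toward uniform, so a student that commits to either teacher is penalized. That is one plausible reading of the interference described above, not a confirmed diagnosis.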
2. Set Transformer datapoint attention
Hypothesis: Replace the row-attention block with a Set Transformer's induced set attention block (ISAB). Should improve permutation invariance and stabilize the embedding of the row dimension.
Outcome: Tied with Base, but 1.4× slower per step. The permutation-invariance improvement is real and measurable on synthetic data, but it doesn't transfer to OpenML in any benchmark-visible way. We paid for it in compute and got nothing back.
Verdict: Lost. The standard row-attention block is already permutation-invariant enough in practice.
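For readers who want to check the permutation claim themselves, here is a small sketch of how row-order sensitivity can be measured. It assumes plain single-head self-attention over rows with no positional encoding, written in NumPy to stay self-contained; it is not the PredictLM block, and the post doesn't describe the exact measurement used.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over rows, with no positional encoding."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
rows, d = 8, 4
x = rng.normal(size=(rows, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(rows)
out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)

# Equivariance: shuffling the input rows should shuffle the outputs the same way.
max_gap = float(np.abs(out[perm] - out_perm).max())
print(f"max deviation under row permutation: {max_gap:.2e}")
```

Under these assumptions the gap is at floating-point rounding level, which fits the verdict that the standard block leaves little for ISAB to improve on.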
3. DiffFormer attention
Hypothesis: DiffFormer — differential attention with two query/key projections subtracted — reduces hallucination on long-context LLMs. Tabular FMs have a related problem: spurious feature correlations across context rows.
Outcome: Significantly worse, by 3.2 pp R². The subtraction-based attention is doing the wrong thing for tabular data, where correlated features are usually signal, not noise. DiffFormer is a language-model fix that doesn't generalize.
Verdict: Lost. Same architecture, no DiffFormer.
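As a rough illustration of the mechanism, the sketch below implements a single-head differential attention map in NumPy: two softmax attention maps from separate query/key projections, with the second subtracted from the first. It is simplified on purpose. The scalar `lam` is fixed rather than learned, per-head normalization is omitted, and all names are hypothetical rather than taken from the training code.

```python
import numpy as np

def softmax(scores):
    z = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def differential_attention(x, w_q1, w_k1, w_q2, w_k2, w_v, lam=0.5):
    """Attention weights are the difference of two softmax maps (simplified)."""
    d = w_q1.shape[1]
    a1 = softmax((x @ w_q1) @ (x @ w_k1).T / np.sqrt(d))
    a2 = softmax((x @ w_q2) @ (x @ w_k2).T / np.sqrt(d))
    weights = a1 - lam * a2  # pattern shared by both maps is suppressed
    return weights @ (x @ w_v), weights

rng = np.random.default_rng(0)
rows, d = 6, 4
x = rng.normal(size=(rows, d))
w_q1, w_k1, w_q2, w_k2, w_v = (rng.normal(size=(d, d)) for _ in range(5))

out, weights = differential_attention(x, w_q1, w_k1, w_q2, w_k2, w_v, lam=0.5)

# Extreme case: identical projections and lam=1 cancel the attention entirely.
cancelled, _ = differential_attention(x, w_q1, w_k1, w_q1, w_k1, w_v, lam=1.0)
print(np.abs(cancelled).max())
```

The extreme case shows the failure mode in miniature: whatever both maps agree on is subtracted away, which helps when the shared pattern is noise and hurts when it is signal.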
4. "Go bigger" — 57M combined model
Hypothesis: All of the above + more parameters. 57M params, 16 layers, d_model = 384. Maybe we're underfit.
Outcome: Mean R² delta of +0.4 pp over Base — within sampling noise. The model is twice the size and uses twice the compute per token. Not paying for itself.
Verdict: Lost. We're not underfit at 26M for this data distribution.
5. MPS-throughput mitigations
Hypothesis: A separate experiment — can we train the base on a beefy M-series Mac at acceptable throughput? Useful for fast local iteration.
Outcome: MPS throughput is roughly 0.35× that of a single Tesla T4, even after the Flash Attention port and the manual torch.compile workarounds. The gradient-NaN issue we hit (see below) came from an MPS-specific bfloat16 path that ships with PyTorch 2.5.
Verdict: Conceded. CUDA-only for training.
6. Teacher-free training
Hypothesis: Skip the teacher entirely. Train Mini from scratch with the same data + loss as Base.
Outcome: Diverged after 12k steps with a regression-head gradient explosion. We've seen this before in earlier PredictLM versions — the BarDistribution head needs the smoothing pressure of a teacher distribution to stay stable in float16. With a sharper from-scratch target, the head amplifies the gradients.
Verdict: Lost. Distillation is doing real work, not just incremental improvement.
The gradient-NaN debugging
Three of the runs above hit gradient NaNs in the first 30k steps. The trail:
- First suspicion: mixed-precision regression head divergence. We've fixed this before by tanh-bounding log_var. But the NaNs were on classification too.
- Second suspicion: a bad batch. We added per-batch loss-watchers and torch.isnan checks before the optimizer.step(). Whenever a batch produced NaN, we skipped it. This worked for the first few hours, then NaN rate climbed and convergence stalled.
- Third suspicion: the data loader. Looking at the skipped batches, we noticed they all had context-row counts in a narrow band (~480 rows) and high feature dimensionality (>200). The attention scores at this regime were exceeding bfloat16 max for a small fraction of heads.
- The fix: force float32 across the attention path. The forward pass costs ~12% more memory and ~6% more wall-clock. Mini's full training run is 3.3 hours. We can afford 6%.
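The general pattern behind the fix can be sketched as follows: keep inputs and outputs in the low-precision dtype, but compute attention scores and the softmax in float32. This is an illustration under stated assumptions, not the training code. NumPy has no bfloat16, so float16 stands in as the low-precision dtype (its range differs from bfloat16's), and the all-20s input is contrived to force an overflow.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention(q, k, v, compute_dtype):
    """Scaled dot-product attention with scores and softmax in compute_dtype."""
    out_dtype = q.dtype
    q, k, v = (t.astype(compute_dtype) for t in (q, k, v))
    scale = np.asarray(1.0 / np.sqrt(q.shape[-1]), dtype=compute_dtype)
    scores = (q @ k.T) * scale
    return (softmax(scores) @ v).astype(out_dtype)  # cast back for the caller

# Contrived wide input: each q.k dot product is 256 * 20 * 20 = 102400,
# which is above the float16 maximum of 65504.
rows, features = 4, 256
x = np.full((rows, features), 20.0, dtype=np.float16)

with np.errstate(all="ignore"):  # silence overflow warnings in the demo
    out_half = attention(x, x, x, np.float16)  # scores overflow to inf, then NaN
out_safe = attention(x, x, x, np.float32)      # scores stay finite

print(np.isnan(out_half).all(), np.isfinite(out_safe).all())
```

The upcast touches only the attention path, so the rest of the network can stay in reduced precision. That is why the cost is a modest memory and wall-clock overhead rather than a full-precision run.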
Lesson: Don't paper over gradient NaNs with batch skipping. The skipped batches are a signal about which inputs your model can't handle in your chosen precision. Track them, find the regime, fix it at the source.
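One way to act on that lesson is to record the shape of every batch you skip, so the failing regime shows up as a statistic instead of a hunch. The sketch below uses only the standard library. The class name, bucket sizes, and example batches are hypothetical and not taken from the training repo.

```python
import math
from collections import Counter

class NaNBatchLog:
    """Tracks which (context rows, features) regimes produce non-finite losses."""

    def __init__(self, row_bucket=32, feature_bucket=50):
        self.row_bucket = row_bucket
        self.feature_bucket = feature_bucket
        self.seen = Counter()
        self.bad = Counter()

    def _key(self, n_rows, n_features):
        # Round each dimension down to the start of its bucket.
        return (n_rows // self.row_bucket * self.row_bucket,
                n_features // self.feature_bucket * self.feature_bucket)

    def observe(self, loss, n_rows, n_features):
        """Record one batch; return True if the caller should skip it."""
        key = self._key(n_rows, n_features)
        self.seen[key] += 1
        if not math.isfinite(loss):
            self.bad[key] += 1
            return True
        return False

    def worst_regimes(self, top=3):
        """Buckets sorted by the fraction of their batches that went non-finite."""
        rates = [(self.bad[k] / self.seen[k], k) for k in self.bad]
        return sorted(rates, reverse=True)[:top]

# Made-up batches for illustration: (loss, context rows, feature count).
log = NaNBatchLog()
batches = [(float("nan"), 480, 220), (float("nan"), 490, 240),
           (0.7, 128, 20), (0.6, 480, 30), (0.5, 500, 210)]
skipped = [log.observe(*b) for b in batches]
print(log.worst_regimes())
```

A skip counter that only says "N batches dropped" hides the pattern. Keying it by input shape is what turns the skipped batches into a pointer at the precision problem.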
What we shipped
The architecture from PredictLM-Base, unchanged. ALBERT-style parameter sharing for Mini, distillation procedure from the same teacher, float32 attention. The simplest configuration we could justify, with one round of ablations and no flourishes.
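For anyone unfamiliar with ALBERT-style sharing, the idea is that one set of layer weights is reused at every depth, so parameter count stops growing with the number of layers. The sketch below shows that with a toy residual layer standing in for a transformer block. The class name and layer form are illustrative assumptions, not Mini's actual architecture.

```python
import numpy as np

class SharedStack:
    """A stack of `n_layers` applications of ONE shared layer (ALBERT-style)."""

    def __init__(self, d_model, n_layers, seed=0):
        rng = np.random.default_rng(seed)
        # A single weight matrix and bias, reused at every depth.
        self.w = rng.normal(scale=d_model ** -0.5, size=(d_model, d_model))
        self.b = np.zeros(d_model)
        self.n_layers = n_layers

    def num_parameters(self):
        return self.w.size + self.b.size

    def __call__(self, x):
        for _ in range(self.n_layers):
            # Toy residual layer standing in for a full transformer block.
            x = x + np.tanh(x @ self.w + self.b)
        return x

shallow = SharedStack(d_model=16, n_layers=2)
deep = SharedStack(d_model=16, n_layers=12)
y = deep(np.ones((3, 16)))

print(shallow.num_parameters(), deep.num_parameters())
```

Depth still costs compute, since every layer application runs, but the stored parameter count is that of a single layer, which is the property that makes the scheme a natural fit for a small model.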
The full experimental procedure is documented in this post; specific configs and the training repo are kept internal for now.
Why publish the losses
Two reasons. The first is selfish — writing up what didn't work helps us not repeat it. The second is structural: a lab that only publishes wins is doing PR, not research. A 5/6 loss rate is the realistic outcome of architecture search at this scale, and pretending otherwise distorts the field's sense of what's tractable. If you want to read about a six-month run where everything worked, this is the wrong blog.