PRISM: Training Data Prototypes for Language Models
Authors: Dan Ley (Research Scientist Intern), Julius Adebayo (Co-Founder & CEO)
Published: December 08, 2025
We have trained PRISM, a family of interpretable language models, to answer the question: when an LLM predicts the next token, which training samples is it relying on?
PRISM traces its predictions to the training data in a single forward pass, at the same cost as generating a single token.
Across parameter sizes from 124M to 1.6B, PRISM models stay within 5% of their unconstrained counterparts on validation loss and downstream benchmarks, with negligible impact on training time.
Tracing the language model’s outputs to training data: In the following demo, PRISM-1.6B decomposes each token it generates into contributions across a handful of prototypes.
A prototype is a learned pattern that represents a cluster of similar examples in the training data.
Each colored slice on the right shows a prototype’s contribution to the logit for the sampled token. Together, the slices add up exactly to the final logit.
Hover over a slice to see the prototype’s broad category (e.g., Medical & Bio), its more specific role (e.g., “physiology”), and its representative training data snippet that most strongly activates it.
We show an interactive projection (UMAP) visualization of all 16,384 prototypes that PRISM-1.6B learned.
Overall, we observe a prototype dictionary in which a bit more than half of the units specialize in low-level morphology and grammatical scaffolding (~35% and ~20%, respectively), while a large minority capture domain-heavy patterns: medical and biological language (~7%), science and technology (~5%), institutional and civic text (~6%), social, demographic, or family descriptions (~3%), named entities (~5%), environment and climate content (~2%), and finance and economics (~1%).
The remaining prototypes concentrate on structured artifacts: numbers and time expressions (~6%), boilerplate fragments (~3%), URLs and identifiers (~2%), and miscellaneous patterns (~3%).
We now pick a few prototypes and show the training data snippets they map to.
In the interactive below, each card shows a single learned prototype. The header gives its automatically inferred category and name.
For each prototype, we present the top tokens most strongly associated with it, along with the training data snippets it maps to.
We can directly trace any generated token to prototypes, and from there to the training data.
Intervening on prototypes during text generation
In this demo, we pick one prototype and, at every token, clamp its activation so that its contribution to the sampled token’s logit is forced to be a fixed fraction of PRISM’s original top-1 logit.
This lets us see directly how amplifying or muting a single training pattern (for example, “clinical trial boilerplate” or “fraction arithmetic”) changes the model’s behavior.
Hover over the text to visualize how the sampled token’s probability shifts as a result of boosting the prototype.
Group prototype intervention during generation for Science & Tech (sky blue) prototypes.
In the demo below, we instead act on an entire labeled category: at each token we inspect the top-16 active prototypes, aggregate the logit signatures of those tagged Science & Tech, and add or subtract a fixed fraction of that aggregate, to amplify or reduce the influence of science and tech patterns wherever they appear in the mixture.
Suppressing this category removes the sky blue highlights and shifts the text toward other patterns (such as institutional or civic language), while boosting it produces more science and tech content, like references to web browsers and email infrastructure.
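As a rough sketch of how such a category-level intervention can be implemented (the function and variable names below are our own, not PRISM’s released API), one can aggregate the logit signatures of the top active prototypes carrying a given tag and shift the logits by a signed fraction of that aggregate:

```python
import torch

def category_intervention(alpha, P, W, residual, category_mask, k=16, strength=0.5):
    """Illustrative sketch (not PRISM's released code): boost or suppress all
    prototypes tagged with one category during a single decoding step.

    alpha:         (K,) sparse prototype activations at this step
    P:             (d, K) prototype codebook
    W:             (V, d) output projection
    residual:      (d,) residual state r_t = z_t - z_hat_t
    category_mask: (K,) bool tensor, True for prototypes tagged e.g. Science & Tech
    strength:      signed fraction; negative values suppress the category
    """
    signatures = W @ P                          # (V, K) fixed logit signatures W p_i
    logits = W @ residual + signatures @ alpha  # original PRISM logits

    # Restrict to the top-k active prototypes, then to those carrying the tag.
    _, top_idx = torch.topk(alpha, k)
    tagged = top_idx[category_mask[top_idx]]

    # Aggregate the tagged prototypes' weighted signatures and shift the logits.
    aggregate = signatures[:, tagged] @ alpha[tagged]   # (V,)
    return logits + strength * aggregate
```

Setting strength to a negative value suppresses the category; the single-prototype intervention above works analogously, rescaling one activation so that its contribution matches a target fraction of the original top-1 logit.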
Introduction
Generative AI has a data provenance challenge.
AI labs have paid record-breaking settlements over training data.
Others face ongoing litigation from publishers.
When statutory damages reach exorbitant amounts, the question at the center of these cases becomes urgent: when a model generates an output, what training data is it relying on?
This problem, training data attribution (TDA), matters beyond the courtroom.
Reliable attribution lets us value data appropriately, understand how LLMs solve hard problems, and verify their outputs.
We would prefer a model answering a medical question to rely on journal articles rather than personal blog posts.
PRISM takes a different approach from post-hoc attribution methods: it ties training data attribution directly to the model architecture.
Every prediction decomposes into a sparse combination of learned prototypes: patterns corresponding to clusters of training examples.
The architecture is explicitly constrained so that every output logit can be faithfully traced back to these clusters.
Consequently, attributing the model’s output back to the training data costs a single forward pass.
A medical answer might draw 60% from a prototype grounded in peer-reviewed abstracts; a code completion might trace back to documentation rather than Stack Overflow.
In the sections that follow, we cover PRISM’s architecture and training losses, our automated pipeline for labeling prototypes and retrieving training neighbors, and scaling results from 124M to 1.6B parameters, where PRISM stays within 5% of baseline with under 2% overhead.
PRISM Architecture & Loss Functions
We now discuss the key technical underpinnings of our approach, introducing the prototype matrix, routing rule, residual path, and training losses.
Standard LM heads collapse all learned patterns into a single dense weight matrix; no row or column corresponds to a reusable training pattern.
PRISM asks: what if we made the logit layer an interpretable map of the training data?
We leave the transformer backbone unchanged and only modify the output layer.
Instead of sending the final hidden state $z_t$ directly through a dense matrix $W$, we first express $z_t$ as a sparse, non-negative mixture of prototypes plus a residual, then map to logits.
Two components replace the dense LM head:
a bank of prototypes, each intended to automatically learn a recurring pattern in the training data while being strongly tied to specific training instances; and
a sparse mixing mechanism that, given the current hidden state $z_t$, selects a small set of relevant prototypes and combines their contributions to produce the next-token logits, plus a residual term for whatever is not captured by the prototypes.
Informally: the model asks which few prototypes does this context resemble, and how do they score possible next tokens?
Notation
Embedding dimension $d$, number of prototypes $K$, vocabulary size $V$, training dataset size $N$.
At step $t$, the decoder hidden state is $z_t \in \mathbb{R}^d$ and the vocabulary logits are $\ell_t \in \mathbb{R}^V$.
Let $P = [p_1, \ldots, p_K] \in \mathbb{R}^{d \times K}$ denote the prototype codebook and $\alpha_t \in \mathbb{R}_{\geq 0}^{K}$ the (sparse) prototype activations. Prototypes live within the model’s final-layer embedding space.
We write $W \in \mathbb{R}^{V \times d}$ for the LM’s output projection (optionally tied to the embedding).
The most direct way to bring this-looks-like-that into next-token prediction is to treat it as a $V$-way classification problem: compute $K$ prototype activations and mix them into $V$ vocabulary scores with a dense matrix $M \in \mathbb{R}^{V \times K}$, as in ProtoPNet-style classifiers. This adds $KV$ parameters and $O(KV)$ FLOPs per token on top of the $O(Kd)$ prototype-similarity cost; with vocabularies of $V \approx 50{,}000$, even moderate $K$ already implies tens to hundreds of millions of new weights (e.g., $K = 2{,}000 \Rightarrow KV = 100\text{M}$), making this approach prohibitively expensive at language-model scale.
PRISM’s head instead keeps computation in the model’s embedding space, forming a reconstruction
$$\hat{z}_t = \sum_{k} \alpha_{t,k}\, p_k$$
and applying $W \in \mathbb{R}^{V \times d}$ to obtain logits: $W \hat{z}_t = (WP)\alpha_t$. This is functionally equivalent to a ProtoPNet-style mixer with $M = WP$ while yielding a significant parameter reduction (e.g., $d \approx 500 \Rightarrow Kd + dV = 1\text{M} + 25\text{M}$, or $1\text{M}$ with weight tying), and it preserves metric continuity by avoiding a coarse $d \to K \to V$ switch. Empirically, we find that this reparameterization trains faster and more smoothly: on toy experiments with the TinyStories dataset, the ProtoPNet-style head required up to 3× longer wall-clock time to reach the same perplexity.
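To make the parameter accounting concrete, here is the back-of-the-envelope comparison in plain Python, using the example values quoted above (illustrative values only, not PRISM’s actual configuration):

```python
# Back-of-the-envelope parameter accounting for the two head designs
# (example values from the text, not PRISM's actual configuration).
V = 50_000   # vocabulary size
d = 500      # embedding dimension
K = 2_000    # number of prototypes

# ProtoPNet-style head: a dense K -> V mixer M on top of the prototype scores.
protopnet_extra = K * V                          # 100,000,000 new weights

# PRISM head: prototypes live in embedding space and reuse the output projection W.
prism_prototypes = K * d                         # 1,000,000 new weights
prism_with_untied_W = prism_prototypes + d * V   # + 25,000,000 if W is not weight-tied

print(f"ProtoPNet-style mixer: {protopnet_extra / 1e6:.0f}M extra parameters")
print(f"PRISM prototypes only: {prism_prototypes / 1e6:.0f}M (with weight tying)")
print(f"PRISM with untied W:   {prism_with_untied_W / 1e6:.0f}M")
```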
Following the literature, we adopt an autoregressive backbone as input to the prototype layer. Our modifications are restricted to the final layer, so PRISM can be implemented in a way that is compatible with standard transformer training recipes and, in principle, could also be adapted to other sequence models such as diffusion-based language models. We train the entire model end-to-end, allowing PRISM to learn its own prototypical representation of inputs.
Positive similarity scoring and top-k routing
Once the backbone GPT model has processed the current input, the prototype layer computes the similarity of the input $z_t$ to every prototype in the bank $[p_1, \ldots, p_K]$. For each prototype $p_i$, we compute its cosine similarity to the current state $z_t$:
$$c_{t,i} = \frac{z_t^\top p_i}{\|z_t\|_2\, \|p_i\|_2} \quad \text{for } i \in [K]$$
We apply an optional learned scalar $\tau \in \mathbb{R}_{>0}$ to expand the effective dynamic range of the cosine scores. Intuitively, we want the model to expose non-negative reasoning: predictions are explained as this-looks-like-that (positive evidence from similar prototypes) rather than this-does-not-look-like-that (subtractive evidence). Thus, we enforce non-negativity via a rectifier:
$$\tilde{\alpha}_{t,i} = \mathrm{ReLU}(\tau\, c_{t,i})$$
We select the index set $\mathcal{K}_t = \mathrm{TopK}\big(\{\tilde{\alpha}_{t,i}\}_{i=1}^{K},\, k\big)$ and define the final, few-hot similarities
$$\alpha_{t,i} = \tilde{\alpha}_{t,i}\, \mathbf{1}\{i \in \mathcal{K}_t\}$$
This top-k routing ensures that each token prediction is explained in terms of a small, human-readable set of prototypes rather than a dense mixture over all K.
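A minimal PyTorch sketch of this scoring-and-routing step (our own naming, written from the equations above rather than PRISM’s released implementation):

```python
import torch
import torch.nn.functional as F

def route_to_prototypes(z_t, P, tau=1.0, k=16):
    """Sketch of PRISM's positive similarity scoring and top-k routing.

    z_t: (d,) current hidden state
    P:   (d, K) prototype codebook
    Returns the (K,) sparse, non-negative activations alpha_t.
    """
    # Cosine similarity between z_t and every prototype column of P.
    cos = F.normalize(P, dim=0).T @ F.normalize(z_t, dim=0)   # (K,)
    # Rectify (positive, this-looks-like-that evidence only) and scale by tau.
    alpha_tilde = F.relu(tau * cos)
    # Keep only the k largest activations; zero out everything else.
    top_vals, top_idx = torch.topk(alpha_tilde, k)
    alpha = torch.zeros_like(alpha_tilde).scatter_(0, top_idx, top_vals)
    return alpha
```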
Sparse reconstruction
We would like to reason about a prediction using as few prototypical contexts as possible, to enable crisp interpretability. Sparse activations encourage each prototype to specialize and represent tighter clusters of the training data, which makes it easier to summarize what the model is “thinking” in terms of a handful of distinct patterns, i.e., the prototype logit signatures become more fine-grained. Given the $k$ most similar prototypes, we form a $k$-sparse reconstruction
$$\hat{z}_t = P\alpha_t = \sum_{i \in \mathcal{K}_t} \alpha_{t,i}\, p_i \in \mathbb{R}^d$$
This follows the existing sparse autoencoder (SAE) literature, which learns sparse dictionaries for hidden states at intermediate layers. In contrast, PRISM learns a sparse dictionary of training-grounded prototypes that directly explain the model’s output logits without a separate decoder. The learned features are also directly tied to groups of training examples (see next section).
Merge and logits
We use a residual merge with the original state to account for parts of the input not reconstructed by the prototypes. The residual $r_t = z_t - \hat{z}_t$ is the difference between the original state $z_t$ and the reconstruction $\hat{z}_t$, so the merged state is $z_t' = \hat{z}_t + r_t = z_t$. The vocabulary projection is standard:
$$\ell_t = W z_t' = W(\hat{z}_t + r_t)$$
Keeping an explicit residual path preserves the expressivity of the original backbone. Rare or input-dependent tokens need not be forced through the prototype dictionary. Measuring how much of each prediction is accounted for by prototypes versus the residual is straightforward.
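Putting the pieces together, a sketch of the full head forward pass, building on route_to_prototypes above (again our own naming, not the released code):

```python
def prism_head(z_t, P, W, tau=1.0, k=16):
    """Sketch of the PRISM head: routing, sparse reconstruction,
    residual merge, and vocabulary projection.

    z_t: (d,) hidden state,  P: (d, K) prototypes,  W: (V, d) output projection
    """
    alpha = route_to_prototypes(z_t, P, tau=tau, k=k)   # (K,) sparse activations
    z_hat = P @ alpha                                   # k-sparse reconstruction
    r_t = z_t - z_hat                                   # residual path
    logits = W @ (z_hat + r_t)                          # = W @ z_t, since z'_t = z_t
    return logits, alpha, r_t
```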
Faithful logit decomposition
The PRISM head builds an interpretable logit map at the model’s final layer, ensuring that we can directly quantify the effect and importance of any prototype on any output token by design. By linearity of $W$, the next-token logits decompose into per-prototype contributions:
$$\ell_t = W r_t + \sum_{i \in \mathcal{K}_t} \alpha_{t,i}\, (W p_i)$$
Each prototype $p_i$ thus induces a fixed token–logit signature $W p_i \in \mathbb{R}^V$, and the model’s prediction is an explicit, sparse, non-negative mixture over at most $k$ such signatures. This yields additive, causally faithful units that can be ablated or amplified directly at the logit level. When the model predicts a given token, we can recover a given prototype’s exact contribution simply by multiplying its activation $\alpha_{t,i}$ by its fixed logit signature $W p_i$ (indexed at the predicted token). As a matter of preference, we fold the scalar $\tau$ into $W p_i$ when interpreting the prototype signature; this restricts our interpretation of the final logits to a weighted superposition, with weights in $[0, 1]$, of the top-$k$ prototype signatures.
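Under the same sketch, recovering a prototype’s exact contribution to a given token takes a few lines; the decomposition below reproduces the sum above exactly:

```python
def decompose_logit(alpha, P, W, r_t, token_id):
    """Decompose one token's logit into a residual part plus per-prototype parts.

    Returns (residual_part, contribs), where contribs[i] = alpha_i * (W p_i)[token_id]
    and residual_part + contribs.sum() equals the token's logit exactly.
    """
    signatures = W @ P                        # (V, K) fixed logit signatures W p_i
    contribs = alpha * signatures[token_id]   # (K,) per-prototype contributions
    residual_part = W[token_id] @ r_t         # contribution of the residual path
    return residual_part, contribs
```

Ablating a prototype amounts to zeroing its entry in alpha before the projection; amplifying it scales that entry, which mirrors the intervention demos above.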
Loss functions
Here we detail the loss functions used to train PRISM. Let $I(B)$ denote the index set of token positions across the current macro-batch. Additionally, let $d_i(j) = -c(p_i, z_j)$ be the negative cosine similarity (used as a distance) between prototype $i$ and the token representation at position $j \in I(B)$.
$$\mathcal{L}_{\text{PRISM}} = \mathcal{L}_{\text{CE}} + \underbrace{\mathcal{L}_{R1} + \mathcal{L}_{R2}}_{\text{clustering losses}} + \mathcal{L}_{\text{RES}}$$
Cross-Entropy ($\mathcal{L}_{\text{CE}}$).
We use the standard objective
$$\mathcal{L}_{\text{CE}} = -\frac{1}{|B|} \sum_{(x_{1:T}) \in B} \sum_{t=1}^{T-1} \log p_\theta\!\left(x_{t+1} \mid x_{\leq t}\right)$$
where $p_\theta(x_{t+1} \mid x_{\leq t}) = \mathrm{softmax}(\ell_t)_{x_{t+1}}$ and $\ell_t$ are the logits computed from the merged state $z_t'$.
Prototype Pull ($\mathcal{L}_{R1}$).
We encourage each prototype to anchor to some token in the batch with
$$\mathcal{L}_{R1} = \frac{1}{K} \sum_{i=1}^{K} \min_{j \in I(B)} d_i(j).$$
Training-Point Pull ($\mathcal{L}_{R2}$).
Symmetrically, every token position should be close to at least one prototype via
$$\mathcal{L}_{R2} = \frac{1}{|I(B)|} \sum_{j \in I(B)} \min_{i \in [K]} d_i(j).$$
Combined, the R1 and R2 terms can be viewed as clustering losses in the backbone LM’s final layer embedding.
Residual ($\mathcal{L}_{\text{RES}}$).
We set $z_t' = \hat{z}_t + r_t$ with $r_t := z_t - \hat{z}_t$. We simply minimize the mean squared residual
$$\mathcal{L}_{\text{RES}} = \|r_t\|_2^2 = \Big\| z_t - \sum_{i \in \mathcal{K}_t} \alpha_{t,i}\, p_i \Big\|_2^2$$
i.e., the mean squared mismatch between the prototype reconstruction and the original state.
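A batched sketch of the clustering and residual terms in the same notation (our own implementation, not the released training code; the residual term is averaged over positions here):

```python
import torch
import torch.nn.functional as F

def prism_aux_losses(Z, P, Z_hat):
    """Sketch of the prototype-pull, training-point-pull, and residual losses.

    Z:     (N, d) token representations for all positions in the macro-batch
    P:     (d, K) prototype codebook
    Z_hat: (N, d) sparse prototype reconstructions of the rows of Z
    """
    # d_i(j) = -cos(p_i, z_j): negative cosine similarity, used as a distance.
    D = -(F.normalize(Z, dim=1) @ F.normalize(P, dim=0))   # (N, K)

    # L_R1: every prototype should anchor to some token in the batch.
    loss_r1 = D.min(dim=0).values.mean()
    # L_R2: every token position should be close to at least one prototype.
    loss_r2 = D.min(dim=1).values.mean()
    # L_RES: squared mismatch between each state and its reconstruction.
    loss_res = ((Z - Z_hat) ** 2).sum(dim=1).mean()

    return loss_r1, loss_r2, loss_res
```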
(Optional) Prototype Diversity ($\mathcal{L}_{\text{DIV}}$).
We optionally encourage prototypes to cover diverse representations within the final layer’s embedding space, to reduce prototype overlap and encourage specialization. For this setting, we penalize the off-diagonal coherence of the $\ell_2$-normalized prototypes. With $\tilde{p}_i = p_i / \|p_i\|_2$ and $G = \tilde{P}^\top \tilde{P}$,
$$\mathcal{L}_{\text{DIV}} = \frac{1}{K(K-1)} \sum_{i \neq j} G_{ij}^2.$$
For $K > d$, the average squared coherence is lower bounded by the Welch bound [10]. Driving $\mathcal{L}_{\text{DIV}}$ toward this limit spreads prototypes nearly optimally on $S^{d-1}$ and empirically yields crisper, more distinct roles without harming validation cross-entropy.
[10] Welch, 1974. Lower bounds on the maximum cross correlation of signals.
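A sketch of the coherence penalty (illustrative naming, written from the formula above):

```python
import torch
import torch.nn.functional as F

def diversity_loss(P):
    """Average squared off-diagonal coherence of the L2-normalized prototypes."""
    P_norm = F.normalize(P, dim=0)                 # (d, K) unit-norm columns
    G = P_norm.T @ P_norm                          # (K, K) Gram matrix
    K = G.shape[0]
    off_diag = G - torch.diag(torch.diagonal(G))   # zero out the diagonal
    return (off_diag ** 2).sum() / (K * (K - 1))
```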
Training Data Attribution in a Single Forward Pass
PRISM exposes all of the quantities needed for attribution during inference. Given the hidden state $z_t$, the attribution map $A_t$ distributes each selected prototype’s logit contribution $\alpha_{t,i}(W p_i)$ over that prototype’s precomputed training neighbors, where $S_i$ is the precomputed set of training tokens nearest to prototype $i$ and $\mu_{S_i}$ is a weighting over that set (commonly uniform). $A_t$ is fully determined by forward-pass values and static mappings: no gradients, no Hessians, no dataset search.
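As one concrete, assumed instantiation of this recipe, each active prototype’s logit contribution to the sampled token can be spread over its precomputed neighbor set $S_i$ according to $\mu_{S_i}$; the function below is our own sketch, and the exact weighting PRISM uses may differ:

```python
import torch

def attribute_to_training_data(alpha, P, W, token_id, neighbor_sets, weights=None):
    """Hypothetical sketch of single-forward-pass attribution (our assumption of
    the weighting, not PRISM's exact formula).

    alpha:         (K,) sparse activations from the forward pass
    P, W:          prototype codebook (d, K) and output projection (V, d)
    token_id:      the sampled token
    neighbor_sets: dict {prototype index -> list of training positions S_i}
    weights:       optional dict {prototype index -> weights mu_{S_i}}; uniform if None
    """
    signatures = W @ P                          # (V, K) fixed logit signatures
    contribs = alpha * signatures[token_id]     # per-prototype logit contributions
    attribution = {}                            # training position -> attribution score
    for i in torch.nonzero(alpha).flatten().tolist():
        S_i = neighbor_sets[i]
        mu = weights[i] if weights is not None else [1.0 / len(S_i)] * len(S_i)
        for pos, w in zip(S_i, mu):
            attribution[pos] = attribution.get(pos, 0.0) + float(contribs[i]) * w
    return attribution
```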
Automated Interpretability Pipeline
PRISM gives us two handles for automation: each prototype is tied to training contexts via its activations, and each has a fixed logit signature $W p_i$ over the vocabulary. We use these to (i) find the training snippets each prototype represents, and (ii) assign human-readable labels.
Nearest Neighbor Search
For each prototype, we recover concrete training examples with a single streaming pass over the dataset. We retain the top-$L$ positions with the highest activations:
For every token position $j$, compute $\alpha_{j,i}$ for all prototypes
Maintain a bounded heap of size $L$ per prototype storing its best matches
Enforce distinct-position constraints to avoid redundant sliding-window variants
This is a one-pass $O(NK)$ procedure with $O(KL)$ memory. Because the $\mathcal{L}_{R1}$ loss pulls each prototype toward training tokens, high-activation neighbors exist by construction. In practice, the retrieved neighbor sets stabilize after scanning roughly 1% of the training data.
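A simplified pure-Python sketch of the streaming pass (our code; it keeps a bounded min-heap of size L per prototype and omits the distinct-position constraint for brevity):

```python
import heapq

def stream_top_neighbors(activation_batches, L=32):
    """Keep, for each prototype, the L training positions with highest activation.

    activation_batches yields (positions, alpha) pairs, where alpha is an
    (n_positions, K) array of prototype activations from a forward pass.
    """
    heaps = None  # one bounded min-heap of (activation, position) per prototype
    for positions, alpha in activation_batches:
        n_pos, K = alpha.shape
        if heaps is None:
            heaps = [[] for _ in range(K)]
        for row, pos in enumerate(positions):
            for i in range(K):
                a = float(alpha[row, i])
                if len(heaps[i]) < L:
                    heapq.heappush(heaps[i], (a, pos))
                elif a > heaps[i][0][0]:
                    heapq.heapreplace(heaps[i], (a, pos))
    return [sorted(h, reverse=True) for h in heaps]
```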
Automatic Labeling
For each prototype, we build a compact “card” containing (i) top tokens from its logit signature and (ii) local contexts where it fires. A small labeling model converts this into human-readable metadata: a short name, a one-line description (e.g., “clinical trial boilerplate”, “Unix timestamps”), and example contexts.
A second pass assigns coarse tags used in visualizations: broad category (Science & Tech, Numbers & Time, URLs & IDs), syntactic role (noun-like, function word, scaffold phrase), and optional domain tags (medical, US universities). This runs offline on learned prototypes and their neighbors.
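A sketch of assembling one such card (the names here are ours, and we assume a Hugging Face-style tokenizer with a decode method):

```python
import torch

def build_prototype_card(i, P, W, neighbor_snippets, tokenizer, n_tokens=10):
    """Assemble a labeling 'card' for prototype i from its logit signature
    and its retrieved training contexts (illustrative sketch).

    neighbor_snippets: list of training-context strings retrieved for prototype i
    """
    signature = W @ P[:, i]                            # (V,) fixed logit signature W p_i
    top_ids = torch.topk(signature, n_tokens).indices
    return {
        "prototype": i,
        "top_tokens": [tokenizer.decode([t]) for t in top_ids.tolist()],
        "contexts": neighbor_snippets[:8],             # a few contexts where it fires
    }
```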
Performance & Scaling
We now discuss the training procedure and performance details from training PRISM end-to-end across various model sizes.
Scaling from 124M to 1.6B
Overview of PRISM performance compared to an unconstrained GPT model
We train GPT backbones from 124M to 1.6B parameters end-to-end with the prototype layer for one epoch on FineWeb-Edu-10B. PRISM stays within 5% of unconstrained baselines on validation loss and downstream benchmarks across all scales. The prototype layer adds $d \times K$ parameters: at GPT-XL scale with $K = 16{,}384$ prototypes, this is 26M parameters (1.7% overhead).
Training time increases by less than 2%. The overhead shrinks as a fraction of total parameters as backbones scale up.
Faithful attribution to training data does not require sacrificing model quality.
| Parameter | Small | Medium | Large | XL |
| --- | --- | --- | --- | --- |
| Block Size | 1024 | 1024 | 1024 | 1024 |
| Embed. Dim. | 768 | 1024 | 1280 | 1600 |
| No. Heads | 12 | 16 | 20 | 25 |
| No. Layers | 12 | 24 | 36 | 48 |
| Total Parameters | 124M | 355M | 774M | 1.558B |
| Embed. Dim. \ No. of Prototypes (K) | 4096 | 8192 | 16384 | 32768 |
| --- | --- | --- | --- | --- |
| 768 (S) | 3.15M (+2.5%) | 6.29M (+5.1%) | 12.58M (+10.1%) | 25.17M (+20.2%) |
| 1024 (M) | 4.19M (+1.2%) | 8.39M (+2.4%) | 16.78M (+4.7%) | 33.55M (+9.5%) |
| 1280 (L) | 5.24M (+0.7%) | 10.49M (+1.4%) | 20.97M (+2.7%) | 41.94M (+5.4%) |
| 1600 (XL) | 6.55M (+0.4%) | 13.11M (+0.8%) | 26.21M (+1.7%) | 52.43M (+3.4%) |
Conclusion
PRISM demonstrates that training data attribution doesn’t have to be a post-hoc approximation bolted onto an opaque model.
By building interpretability into the architecture, we get faithful explanations at the cost of a single forward pass.
At 1.6B parameters, PRISM stays within 5% of baseline performance with under 2% overhead.
The prototype dictionary is inspectable, editable, and directly tied to training data.
This work lays a foundation for language models that can be more easily audited and steered, and whose predictions can be faithfully traced to the training data.
Our results indicate that, at GPT-XL scale, there are solutions comfortably within 5% of the original backbone’s performance that satisfy PRISM’s interpretability constraints.
In this view, PRISM does not enforce an accuracy–interpretability tradeoff so much as bias optimization toward a part of the Rashomon set where the logit layer admits a structured, training-data–grounded decomposition into prototypes.