Multi-Head Constraint Implementation for Precision and Scope (final)
-
Variant A (additive capacity): increase head count; concat base+constraint heads; one output projection back to d_model.
-
Variant B (constant capacity): keep d_model constant by shrinking per-head size when you add constraint heads (so total concat stays ≈ d_model). This trades parameter growth for latency/control.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
class ConstraintConfig:
    """Configuration for constraint-dedicated attention heads.

    Attributes mirror the constructor arguments; `loss_weights` defaults to a
    uniform weight of 1.0 per constraint name when not supplied.
    """

    def __init__(
        self,
        names: List[str] = ("reciprocity", "testifiability", "decidability"),
        heads_per_constraint: int = 1,
        trace_token_scores: bool = True,
        trace_attn_maps: bool = False,
        router_hidden: int = 128,
        loss_weights: Optional[Dict[str, float]] = None,
        variant: str = "A",  # "A" = additive capacity; "B" = constant capacity
    ):
        # Copy into a list so callers can pass any iterable (the default is a
        # tuple precisely to avoid the mutable-default-argument pitfall).
        self.names = list(names)
        self.heads_per_constraint = int(heads_per_constraint)
        self.trace_token_scores = trace_token_scores
        self.trace_attn_maps = trace_attn_maps
        self.router_hidden = router_hidden
        self.loss_weights = loss_weights or {name: 1.0 for name in names}
        # Validate with a real exception: `assert` is stripped under `python -O`.
        if variant not in ("A", "B"):
            raise ValueError(f"variant must be 'A' or 'B', got {variant!r}")
        self.variant = variant
class MultiHeadAttentionWithConstraints(nn.Module):
“””
MHA with extra ‘constraint heads’ dedicated to lawful reasoning.
Returns the standard MHA output plus an audit ‘traces’ dict.
“””
def __init__(
self,
d_model: int,
n_heads_base: int,
n_layers: int = 1, # not used here, but kept for parity with block builders
constraint: Optional[ConstraintConfig] = None,
dropout: float = 0.0,
):
super().__init__()
self.d_model = d_model
self.n_heads_base = n_heads_base
self.constraint = constraint or ConstraintConfig()
self.dropout = nn.Dropout(dropout)
# — head accounting —
self.n_constraint_heads = self.constraint.heads_per_constraint * len(self.constraint.names)
self.n_total_heads = self.n_heads_base + self.n_constraint_heads
# Per-head dimensions.
# Variant A: keep dk=dv=d_model//n_heads_base for base heads; use SAME for constraint heads → more concat width.
# Variant B: set dk=dv so that (n_total_heads * dv) ≈ d_model → constant concat width.
if self.constraint.variant == “A”:
= self.dv = d_model // self.n_heads_base
concat_width = self.dv * self.n_total_heads # grows with extra heads
else:
= self.dv = d_model // self.n_total_heads
concat_width = self.dv * self.n_total_heads # ~== d_model
# — base head projections —
self.Wq_base = nn.Linear(d_model,
* self.n_heads_base, bias=False)
self.Wk_base = nn.Linear(d_model,
* self.n_heads_base, bias=False)
self.Wv_base = nn.Linear(d_model, self.dv * self.n_heads_base, bias=False)
# — constraint head projections (separate parameterization) —
if self.n_constraint_heads > 0:
self.Wq_con = nn.Linear(d_model,
* self.n_constraint_heads, bias=False)
self.Wk_con = nn.Linear(d_model,
* self.n_constraint_heads, bias=False)
self.Wv_con = nn.Linear(d_model, self.dv * self.n_constraint_heads, bias=False)
else:
self.Wq_con = self.Wk_con = self.Wv_con = None
# One output projection over concatenated head outputs
self.Wo = nn.Linear(concat_width, d_model, bias=False)
# — constraint routers & scorers —
# A simple router that gates constraint heads from a pooled token (e.g., first token or mean pool)
route_in = d_model
self.router = nn.Sequential(
nn.Linear(route_in, self.constraint.router_hidden),
nn.ReLU(),
nn.Linear(self.constraint.router_hidden, self.n_constraint_heads),
)
# Per constraint head token scorer (for audit + auxiliary loss)
# We use an MLP->scalar per token for each constraint head.
self.token_scorers = nn.ModuleList([
nn.Sequential(
nn.Linear(self.dv, self.dv),
nn.ReLU(),
nn.Linear(self.dv, 1) # scalar per token
) for _ in range(self.n_constraint_heads)
])
# Helper: map constraint names to contiguous head ranges
self._constraint_spans = {}
idx = 0
for name in self.constraint.names:
self._constraint_spans[name] = (idx, idx + self.constraint.heads_per_constraint)
idx += self.constraint.heads_per_constraint
self.scale = (
) ** -0.5 # 1/sqrt(dk)
def _split_heads(self, x: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
# x: [B, T, n_heads * head_dim] -> [B, n_heads, T, head_dim]
B, T, _ = x.shape
x = x.view(B, T, n_heads, head_dim).transpose(1, 2)
return x # [B, H, T, D]
def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, H, T, D] -> [B, T, H*D]
B, H, T, D = x.shape
return x.transpose(1, 2).contiguous().view(B, T, H * D)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None, # shape broadcastable to [B, 1, T, T]
need_weights: bool = False,
router_hint: Optional[torch.Tensor] = None, # optional [B, d_model] routing context
) -> Tuple[torch.Tensor, Dict]:
“””
x: [B, T, d_model]
returns: (y: [B, T, d_model], traces: dict)
“””
B, T, _ = x.shape
# — projections —
Qb = self._split_heads(self.Wq_base(x), self.n_heads_base,
) # [B, Hb, T, dk]
Kb = self._split_heads(self.Wk_base(x), self.n_heads_base,
)
Vb = self._split_heads(self.Wv_base(x), self.n_heads_base, self.dv)
if self.n_constraint_heads > 0:
Qc = self._split_heads(self.Wq_con(x), self.n_constraint_heads,
) # [B, Hc, T, dk]
Kc = self._split_heads(self.Wk_con(x), self.n_constraint_heads,
)
Vc = self._split_heads(self.Wv_con(x), self.n_constraint_heads, self.dv)
else:
Qc = Kc = Vc = None
# — router gates for constraint heads —
if self.n_constraint_heads > 0:
if router_hint is None:
# Use mean pool over sequence as a cheap context
router_hint = x.mean(dim=1) # [B, d_model]
gates = torch.sigmoid(self.router(router_hint)) # [B, Hc]
gates = gates.view(B, self.n_constraint_heads, 1, 1) # broadcast over T,T
else:
gates = None
# — scaled dot-product attention —
def attn(Q, K, V, gate=None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Q,K,V: [B, H, T, D]
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [B, H, T, T]
if attn_mask is not None:
scores = scores + attn_mask # mask contains 0 or -inf
if gate is not None:
# Light-handed way: add log(gate) to diagonal of scores to boost self-focus when gate is low/high.
# Alternatively, multiply V by gate (next line). We do both minimally:
V = V * gate # [B, H, T, D]
P = torch.softmax(scores, dim=-1)
P = self.dropout(P)
out = torch.matmul(P, V) # [B, H, T, D]
return out, (P if need_weights else None)
Hb, Pb = attn(Qb, Kb, Vb, None)
if self.n_constraint_heads > 0:
Hc, Pc = attn(Qc, Kc, Vc, gates)
else:
Hc, Pc = None, None
# — concat heads and project —
if Hc is not None:
H_all =
([Hb, Hc], dim=1) # [B, Hb+Hc, T, D]
else:
H_all = Hb
Z = self._combine_heads(H_all) # [B, T, (Hb+Hc)*Dv]
y = self.Wo(Z) # [B, T, d_model]
# — traces (audit) —
traces = {“constraint”: {}, “need_weights”: need_weights}
if self.n_constraint_heads > 0:
# token_scores per head: [B, Hc, T]
head_scores = []
for h in range(self.n_constraint_heads):
# Take that head’s token states: [B, 1, T, Dv] -> [B, T, Dv]
h_states = Hc[:, h, :, :] # [B, T, Dv]
s = self.token_scorers[h](h_states).squeeze(-1) # [B, T]
head_scores.append(s)
head_scores = torch.stack(head_scores, dim=1) # [B, Hc, T]
traces[“constraint”][“head_token_scores”] = head_scores # raw per-head token scalars
# Aggregate to named constraints
for name in self.constraint.names:
lo, hi = self._constraint_spans[name]
# mean across heads in that group
token_scores = head_scores[:, lo:hi, :].mean(dim=1) # [B, T]
entry = {“token_scores”: token_scores}
if need_weights and self.constraint.trace_attn_maps:
entry[“attn_maps”] = Pc[:, lo:hi, :, :] # [B, Hc_g, T, T]
traces[“constraint”][name] = entry
# Also expose router gates (per-batch, per-head)
traces[“constraint”][“router_gates”] = gates.squeeze(-1).squeeze(-1) # [B, Hc]
# Optionally expose base attention maps
if need_weights:
traces[“base_attn_maps”] = Pb # [B, Hb, T, T]
return y, traces
.no_grad()
def explain(self, traces: Dict, tokens: List[str], constraint_name: str = “reciprocity”) -> str:
“””
Human-readable audit line from token_scores.
“””
if “constraint” not in traces or constraint_name not in traces[“constraint”]:
return “No constraint trace.”
ts = traces[“constraint”][constraint_name][“token_scores”] # [B, T]
ts = ts[0] # first in batch
# Make a short explanation by selecting top-k contributing tokens
k = min(5, len(tokens))
top_idx = torch.topk(ts, k=k).indices.tolist()
parts = [f”{tokens[i]}:{ts[i]:.2f}” for i in top_idx]
return f”{constraint_name} focus → ” + “, “.join(parts)
Given:
  d_model
  n_heads_base
  constraint_names = [reciprocity, testifiability, decidability]
  heads_per_constraint = Hc_each
  variant ∈ {A = add capacity, B = constant capacity}
  λ = constraint_weight (hyperparameter)
Compute:
n_constraint_heads = Hc_each * len(constraint_names)
n_total_heads = n_heads_base + n_constraint_heads
If variant == A:
dk = dv = floor(d_model / n_heads_base) # keep per-head width; concat grows
concat_width = dv * n_total_heads
Else if variant == B:
dk = dv = floor(d_model / n_total_heads) # shrink per-head width; concat ~ d_model
concat_width = dv * n_total_heads
Parameters:
Base heads:
Wq_base: [d_model, dk * n_heads_base]
Wk_base: [d_model, dk * n_heads_base]
Wv_base: [d_model, dv * n_heads_base]
Constraint heads (separate params):
Wq_con: [d_model, dk * n_constraint_heads]
Wk_con: [d_model, dk * n_constraint_heads]
Wv_con: [d_model, dv * n_constraint_heads]
Output:
Wo: [concat_width, d_model]
Router (for constraint heads):
router: MLP(d_model → hidden → n_constraint_heads)
Token scorers for audit + loss:
For each of n_constraint_heads:
MLP(dv → dv → 1)
Forward(x):
Qb,Kb,Vb = project_and_split(x, base)
Qc,Kc,Vc = project_and_split(x, constraint)
router_hint = mean_pool(x) # or [CLS], or task-specific control
gates = sigmoid(router(router_hint)) # [B, n_constraint_heads]
Hb = attn(Qb, Kb, Vb, mask)
Hc = attn(Qc, Kc, Vc * gates[…,None,None], mask) # gate constraint V or scale scores
H_all = concat(Hb, Hc) # [B, H_total, T, dv]
Z = combine_heads(H_all) # [B, T, concat_width]
y = Wo(Z) # [B, T, d_model]
# Traces:
For each constraint head h:
s_h = scorer_h(Hc[h]) -> [B, T, 1] → squeeze → [B, T]
Group {s_h} by constraint name (average across that name’s heads) → token_scores[name]: [B, T]
Return y, traces = { constraint: { name: { token_scores, (attn_maps?) }, head_token_scores, router_gates } }
Loss:
lm_loss = CE(logits, next_tokens)
c_loss = mean_over_names( BCEWithLogits( token_scores[name], targets[name] ) * weight[name] )
total_loss = lm_loss + λ * c_loss
Notes:
– targets[name] can be dense (0..1) from your NLI labelers or binary.
– You can add sparsity or entropy penalties on router_gates if you want heads to specialize.
– For efficiency, you may compute attn_maps only when need_weights=True (eval/audit).
-
Dedicated representational budget: Constraint heads are architecturally reserved; they can’t be cannibalized by generic correlation pursuit. This injects an inductive bias toward lawful structure (causality/reciprocity/decidability) rather than mere co-occurrence.
-
Routing = conditional compute: The router turns constraint capacity on/off per input. You get specialization without paying the full compute cost on every token. Add entropy/L1/L0 penalties if you want crisper specialization.
-
Traces by construction: The token scorers are cheap MLPs on head outputs. They yield an audit trail (per-token scalars and, if enabled, attention maps). You can serialize these alongside the final answer for explanations and QA.
-
Training stability: Keep λ small at first (e.g., 0.1–0.3) and warm up. If you observe interference with the LM loss, try: stop-grad through constraint branches for the first N steps, or attach constraint losses on later layers only, or use feature matching (KL/Huber) between constraint heads and distilled causal teacher features.
-
Variant selection: Variant A if you want maximum capacity and don’t mind a modest parameter bump. Variant B if you must keep latency/params flat — use more heads but narrower per-head dims.
-
Where to attach: Best returns typically come from mid–late layers (where semantics stabilize). Start by adding a single constraint-augmented block near 2/3 depth, then expand if improvements saturate.
cfg = ConstraintConfig(
    names=["reciprocity", "testifiability", "decidability"],
    heads_per_constraint=1,
    trace_token_scores=True,
    trace_attn_maps=False,
    loss_weights={"reciprocity": 1.0, "testifiability": 0.5, "decidability": 0.5},
    variant="A",
)
block = TransformerBlockWithConstraint(d_model=1024, n_heads_base=16, mlp_ratio=4, dropout=0.1, constraint=cfg)
# x: [B,T,1024]; attn_mask: broadcastable to [B,1,T,T]
y, traces = block(x, attn_mask=None, need_weights=False)
# During training:
lm_loss = language_model_loss(y, targets_next_tokens) # your usual CE
c_targets = {
“reciprocity”: recip_labels, # [B,T] in {0,1} or real
“testifiability”: testif_labels, # [B,T]
“decidability”: decid_labels, # [B,T]
}
c_loss = constraint_loss(traces, c_targets, cfg.loss_weights)
total = lm_loss + 0.2 * c_loss
total.backward(); optimizer.step()
-
Drop-in: This is a drop-in replacement for your MHA sub-module; no need to change the rest of the stack.
-
Costs: Extra heads add projection + attention cost. Variant B caps this via smaller per-head dims. Profiling recommended on your target sequence lengths.
-
Data: You already have the NLI pipeline to emit token-wise labels/scores. If some constraints are sparse (few positive tokens), use focal BCE or reweight positives.
-
Eval: Track (i) canonical LM metrics, (ii) constraint F1/AUROC, and (iii) downstream adjudication tasks (the thing you actually care about). The gains should show up in (iii) even when (i) is flat.
Source date (UTC): 2025-08-25 20:38:50 UTC
Original post: https://x.com/i/articles/1960079443239837872
Leave a Reply