Multi-Head Constraint Implementation for Precision and Scope (final) Below is a

Multi-Head Constraint Implementation for Precision and Scope (final)

Below is a compact, working-style PyTorch implementation you can hand to an engineer, followed by stripped-down pseudocode that shows where losses, routing, and traces plug into training. It treats “extra heads” as constraint heads that (a) run in parallel with ordinary heads, (b) keep their own parameters, (c) expose an audit trace (per-token constraint scores + optional attention maps), and (d) participate in a multi-objective loss (LM + constraint losses).
I’ve kept shapes explicit and avoided magic. Two deployment variants are included:
  • Variant A (additive capacity): increase head count; concat base+constraint heads; one output projection back to d_model.
  • Variant B (constant capacity): keep the concatenated head width ≈ d_model by shrinking the per-head size of every head (base and constraint) when you add constraint heads. This keeps the attention-projection parameter count and cost roughly flat, at the price of narrower heads.
# pytorch>=2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple

class ConstraintConfig:
    """Configuration for the constraint heads of a constraint-augmented attention layer.

    Args:
        names: constraint names; each name owns ``heads_per_constraint``
            contiguous constraint heads. Must be unique.
        heads_per_constraint: number of attention heads per named constraint (>= 0).
        trace_token_scores: whether per-token constraint scores are recorded.
        trace_attn_maps: whether constraint attention maps are recorded.
        router_hidden: hidden width of the router MLP.
        loss_weights: optional per-constraint loss weights. Any name not given
            defaults to 1.0; the mapping is copied, not aliased.
        variant: "A" = additive capacity, "B" = constant capacity.

    Raises:
        ValueError: on an unknown variant, duplicate names, or a negative
            ``heads_per_constraint``.
    """

    def __init__(
        self,
        names: List[str] = ("reciprocity", "testifiability", "decidability"),
        heads_per_constraint: int = 1,
        trace_token_scores: bool = True,
        trace_attn_maps: bool = False,
        router_hidden: int = 128,
        loss_weights: Optional[Dict[str, float]] = None,
        variant: str = "A",  # "A" = additive capacity; "B" = constant capacity
    ):
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if variant not in ("A", "B"):
            raise ValueError(f"variant must be 'A' or 'B', got {variant!r}")

        # Materialize once so generators/tuples are handled and reused safely.
        self.names = list(names)
        if len(set(self.names)) != len(self.names):
            # Duplicate names would make the per-name head spans overlap silently.
            raise ValueError(f"constraint names must be unique, got {self.names}")

        self.heads_per_constraint = int(heads_per_constraint)
        if self.heads_per_constraint < 0:
            raise ValueError(
                f"heads_per_constraint must be >= 0, got {self.heads_per_constraint}"
            )

        self.trace_token_scores = trace_token_scores
        self.trace_attn_maps = trace_attn_maps
        self.router_hidden = router_hidden

        # Every constraint gets a weight (default 1.0); caller values override.
        # Building from self.names (not `names`) avoids an exhausted-generator bug.
        self.loss_weights = {name: 1.0 for name in self.names}
        if loss_weights:
            self.loss_weights.update(loss_weights)

        self.variant = variant

class MultiHeadAttentionWithConstraints(nn.Module):
    """Multi-head attention with extra, separately parameterized 'constraint heads'.

    Base heads and constraint heads attend over the same input in parallel.
    Constraint heads have their own Q/K/V projections. Their value vectors are
    scaled by a per-example sigmoid gate from a small router MLP. Each constraint
    head also feeds a per-token scorer MLP whose outputs form an audit trace.
    All head outputs are concatenated and projected back to ``d_model`` by one
    output projection ``Wo``.

    Head sizing (``dk == dv`` for every head):
      * variant "A": head_dim = d_model // n_heads_base, so the concat width grows
        with the number of constraint heads;
      * otherwise ("B"): head_dim = d_model // n_total_heads, so the concat width
        stays close to d_model.

    ``forward`` returns ``(y, traces)``: the attention output (same shape as the
    input) plus an audit dict.
    """

    def __init__(
        self,
        d_model: int,
        n_heads_base: int,
        n_layers: int = 1,  # not used here, but kept for parity with block builders
        constraint: Optional["ConstraintConfig"] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        if n_heads_base <= 0:
            raise ValueError(f"n_heads_base must be positive, got {n_heads_base}")
        self.d_model = d_model
        self.n_heads_base = n_heads_base
        self.constraint = constraint if constraint is not None else ConstraintConfig()
        self.dropout = nn.Dropout(dropout)
        cfg = self.constraint

        # --- head accounting ---
        self.n_constraint_heads = cfg.heads_per_constraint * len(cfg.names)
        self.n_total_heads = self.n_heads_base + self.n_constraint_heads

        # --- per-head width ---
        # Variant A keeps the base-head width, so the concat width grows with extra heads.
        # Variant B splits d_model across all heads, so the concat width stays ~= d_model.
        heads_sharing_d_model = (
            self.n_heads_base if cfg.variant == "A" else self.n_total_heads
        )
        head_dim = d_model // heads_sharing_d_model
        if head_dim < 1:
            # Floor division to 0 would create zero-width heads and a broken Wo.
            raise ValueError(
                f"per-head dim is {head_dim}: d_model={d_model} is smaller than the "
                f"{heads_sharing_d_model} heads it must be split across"
            )
        self.dk = self.dv = head_dim
        concat_width = self.dv * self.n_total_heads

        # --- base head projections ---
        self.Wq_base = nn.Linear(d_model, self.dk * self.n_heads_base, bias=False)
        self.Wk_base = nn.Linear(d_model, self.dk * self.n_heads_base, bias=False)
        self.Wv_base = nn.Linear(d_model, self.dv * self.n_heads_base, bias=False)

        # --- constraint head projections (separate parameterization) ---
        if self.n_constraint_heads > 0:
            self.Wq_con = nn.Linear(d_model, self.dk * self.n_constraint_heads, bias=False)
            self.Wk_con = nn.Linear(d_model, self.dk * self.n_constraint_heads, bias=False)
            self.Wv_con = nn.Linear(d_model, self.dv * self.n_constraint_heads, bias=False)
        else:
            self.Wq_con = self.Wk_con = self.Wv_con = None

        # One output projection over the concatenated outputs of all heads.
        self.Wo = nn.Linear(concat_width, d_model, bias=False)

        # --- router: routing context [B, d_model] -> one gate logit per constraint head ---
        # Skipped when there are no constraint heads (a 0-output Linear is pointless).
        if self.n_constraint_heads > 0:
            self.router = nn.Sequential(
                nn.Linear(d_model, cfg.router_hidden),
                nn.ReLU(),
                nn.Linear(cfg.router_hidden, self.n_constraint_heads),
            )
        else:
            self.router = None

        # Per-constraint-head token scorer: MLP(dv -> dv -> 1), one scalar per token.
        self.token_scorers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(self.dv, self.dv), nn.ReLU(), nn.Linear(self.dv, 1))
                for _ in range(self.n_constraint_heads)
            ]
        )

        # Map each constraint name to its contiguous [lo, hi) range of constraint heads.
        self._constraint_spans: Dict[str, Tuple[int, int]] = {}
        start = 0
        for name in cfg.names:
            self._constraint_spans[name] = (start, start + cfg.heads_per_constraint)
            start += cfg.heads_per_constraint

        self.scale = self.dk ** -0.5  # 1/sqrt(dk)

    def _split_heads(self, x: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
        """Reshape [B, T, n_heads * head_dim] -> [B, n_heads, T, head_dim]."""
        B, T, _ = x.shape
        return x.view(B, T, n_heads, head_dim).transpose(1, 2)

    def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [B, H, T, D] -> [B, T, H * D]."""
        B, H, T, D = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * D)

    def _attend(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        need_weights: bool,
        gate: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Scaled dot-product attention over [B, H, T, D] inputs.

        ``attn_mask`` is added to the scores (0 = attend, -inf = blocked).
        ``gate`` ([B, H, 1, 1]), if given, scales V per example and head. The
        returned probabilities are taken *before* dropout, so audit maps are not
        perturbed by training noise.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # [B, H, T, T]
        if attn_mask is not None:
            scores = scores + attn_mask
        if gate is not None:
            V = V * gate  # soft gating: scales value vectors; the head is still computed
        probs = torch.softmax(scores, dim=-1)
        out = torch.matmul(self.dropout(probs), V)  # [B, H, T, D]
        return out, (probs if need_weights else None)

    def _constraint_traces(
        self,
        Oc: torch.Tensor,
        Pc: Optional[torch.Tensor],
        gates: torch.Tensor,
        need_weights: bool,
    ) -> Dict:
        """Build the 'constraint' audit dict from constraint-head outputs Oc [B, Hc, T, Dv]."""
        cfg = self.constraint
        traces: Dict = {}

        head_scores = None
        if cfg.trace_token_scores:
            # Raw per-head token scalars: [B, Hc, T].
            head_scores = torch.stack(
                [scorer(Oc[:, h]).squeeze(-1) for h, scorer in enumerate(self.token_scorers)],
                dim=1,
            )
            traces["head_token_scores"] = head_scores

        keep_maps = need_weights and cfg.trace_attn_maps
        for name in cfg.names:
            lo, hi = self._constraint_spans[name]
            entry: Dict = {}
            if head_scores is not None:
                # Mean over the heads that belong to this constraint: [B, T].
                entry["token_scores"] = head_scores[:, lo:hi].mean(dim=1)
            if keep_maps:
                entry["attn_maps"] = Pc[:, lo:hi]  # [B, heads_per_constraint, T, T]
            traces[name] = entry

        traces["router_gates"] = gates  # [B, Hc], values in (0, 1)
        return traces

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # additive, broadcastable to [B, 1, T, T]
        need_weights: bool = False,
        router_hint: Optional[torch.Tensor] = None,  # optional [B, d_model] routing context
    ) -> Tuple[torch.Tensor, Dict]:
        """Run base and constraint heads and collect audit traces.

        Args:
            x: input of shape [B, T, d_model].
            attn_mask: optional additive mask (0 / -inf) applied to both base and
                constraint heads.
            need_weights: if True, attention probabilities are included in traces.
            router_hint: routing context; defaults to the mean of ``x`` over the
                sequence axis (padded positions are included in that mean).

        Returns:
            ``(y, traces)`` where ``y`` is [B, T, d_model]. ``traces`` has keys
            "constraint" (per-name entries with optional "token_scores" [B, T] and
            "attn_maps", plus "head_token_scores" and "router_gates" when constraint
            heads exist), "need_weights", and "base_attn_maps" [B, Hb, T, T] when
            ``need_weights`` is True.
        """
        B, T, _ = x.shape

        # --- base heads ---
        Qb = self._split_heads(self.Wq_base(x), self.n_heads_base, self.dk)  # [B, Hb, T, dk]
        Kb = self._split_heads(self.Wk_base(x), self.n_heads_base, self.dk)
        Vb = self._split_heads(self.Wv_base(x), self.n_heads_base, self.dv)
        Hb, Pb = self._attend(Qb, Kb, Vb, attn_mask, need_weights)

        traces: Dict = {"constraint": {}, "need_weights": need_weights}

        if self.n_constraint_heads > 0:
            n_con = self.n_constraint_heads
            Qc = self._split_heads(self.Wq_con(x), n_con, self.dk)  # [B, Hc, T, dk]
            Kc = self._split_heads(self.Wk_con(x), n_con, self.dk)
            Vc = self._split_heads(self.Wv_con(x), n_con, self.dv)

            if router_hint is None:
                router_hint = x.mean(dim=1)  # cheap pooled context: [B, d_model]
            gates = torch.sigmoid(self.router(router_hint))  # [B, Hc]

            Oc, Pc = self._attend(
                Qc, Kc, Vc, attn_mask, need_weights, gate=gates.view(B, n_con, 1, 1)
            )
            H_all = torch.cat([Hb, Oc], dim=1)  # [B, Hb + Hc, T, Dv]
            traces["constraint"] = self._constraint_traces(Oc, Pc, gates, need_weights)
        else:
            H_all = Hb

        y = self.Wo(self._combine_heads(H_all))  # [B, T, d_model]

        if need_weights:
            traces["base_attn_maps"] = Pb  # [B, Hb, T, T]
        return y, traces

    @torch.no_grad()
    def explain(self, traces: Dict, tokens: List[str], constraint_name: str = "reciprocity") -> str:
        """Return a one-line summary of the top-scoring tokens for one constraint.

        Uses the first example in the batch. Only positions that have a matching
        entry in ``tokens`` are considered; up to 5 tokens are reported.
        """
        entry = traces.get("constraint", {}).get(constraint_name)
        # Non-dict entries (e.g. "router_gates") or disabled token scores have nothing to explain.
        if not isinstance(entry, dict) or "token_scores" not in entry:
            return "No constraint trace."
        scores = entry["token_scores"][0]  # first in batch: [T]
        n = min(len(tokens), scores.shape[0])  # guard against len(tokens) != T
        scores = scores[:n]
        top_idx = torch.topk(scores, k=min(5, n)).indices.tolist()
        parts = [f"{tokens[i]}:{scores[i].item():.2f}" for i in top_idx]
        return f"{constraint_name} focus → " + ", ".join(parts)

Given:
d_model
n_heads_base
constraint_names = [reciprocity, testifiability, decidability]
heads_per_constraint = Hc_each
variant ∈ {A=add capacity, B=constant capacity}
λ = constraint_weight (hyperparameter)

Compute:
n_constraint_heads = Hc_each * len(constraint_names)
n_total_heads = n_heads_base + n_constraint_heads

If variant == A:
dk = dv = floor(d_model / n_heads_base) # keep per-head width; concat grows
concat_width = dv * n_total_heads
Else if variant == B:
dk = dv = floor(d_model / n_total_heads)
# shrink per-head width; concat ~ d_model
concat_width = dv * n_total_heads

Parameters:
Base heads:
Wq_base: [d_model, dk * n_heads_base]
Wk_base: [d_model, dk * n_heads_base]
Wv_base: [d_model, dv * n_heads_base]

Constraint heads (separate params):
Wq_con: [d_model, dk * n_constraint_heads]
Wk_con: [d_model, dk * n_constraint_heads]
Wv_con: [d_model, dv * n_constraint_heads]

Output:
Wo: [concat_width, d_model]

Router (for constraint heads):
router: MLP(d_model → hidden → n_constraint_heads)

Token scorers for audit + loss:
For each of n_constraint_heads:
MLP(dv → dv → 1)

Forward(x):
Qb,Kb,Vb = project_and_split(x, base)
Qc,Kc,Vc = project_and_split(x, constraint)

router_hint = mean_pool(x) # or [CLS], or task-specific control
gates = sigmoid(router(router_hint))
# [B, n_constraint_heads]

Hb = attn(Qb, Kb, Vb, mask)
Hc = attn(Qc, Kc, Vc * gates[…,None,None], mask) # gate constraint V or scale scores

H_all = concat(Hb, Hc) # [B, H_total, T, dv]
Z = combine_heads(H_all)
# [B, T, concat_width]
y = Wo(Z)
# [B, T, d_model]

# Traces:
For each constraint head h:
s_h = scorer_h(Hc[h]) -> [B, T, 1] → squeeze → [B, T]
Group {s_h} by constraint name (average across that name’s heads) → token_scores[name]: [B, T]
Return y, traces = { constraint: { name: { token_scores, (attn_maps?) }, head_token_scores, router_gates } }

Loss:
lm_loss = CE(logits, next_tokens)
c_loss = mean_over_names( BCEWithLogits( token_scores[name], targets[name] ) * weight[name] )
total_loss = lm_loss + λ * c_loss

Notes:
– targets[name] can be dense (0..1) from your NLI labelers or binary.
– You can add sparsity or entropy penalties on router_gates if you want heads to specialize.
– For efficiency, you may compute attn_maps only when need_weights=True (eval/audit).

  1. Dedicated representational budget: constraint heads are architecturally reserved; they can’t be cannibalized by generic correlation pursuit. This injects an inductive bias toward lawful structure (causality/reciprocity/decidability) rather than mere co-occurrence.
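  2. Routing = conditional compute: the router turns constraint capacity up or down per input, which gives you specialization. As implemented, the gate is a soft sigmoid that scales the constraint heads’ value vectors, so every head is still computed on every token. Real compute savings require hard (thresholded or top-k) gating that skips gated-off heads. Add entropy/L1/L0 penalties if you want crisper specialization.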
  3. Traces by construction: the token scorers are cheap MLPs on head outputs. They yield an audit trail (per-token scalars and, if enabled, attention maps). You can serialize these alongside the final answer for explanations and QA.
  4. Training stability: keep λ small at first (e.g., 0.1–0.3) and warm it up. If you observe interference with the LM loss, try one of the following: stop-grad through the constraint branches for the first N steps; attach constraint losses to later layers only; or use feature matching (KL/Huber) between the constraint heads and distilled causal-teacher features.
  5. Variant selection: use Variant A if you want maximum capacity and don’t mind a modest parameter bump. Use Variant B if you must keep latency/params flat—more heads, but narrower per-head dims.
  6. Where to attach: best returns typically come from mid–late layers (where semantics stabilize). Start by adding a single constraint-augmented block near 2/3 depth. Add more such blocks once the gains from that one have saturated.
# Usage sketch. NOTE(review): TransformerBlockWithConstraint, language_model_loss,
# constraint_loss, x, the label tensors, targets_next_tokens, and optimizer are not
# defined in this file; they are assumed to come from your training code.
cfg = ConstraintConfig(
    names=["reciprocity", "testifiability", "decidability"],
    heads_per_constraint=1,
    trace_token_scores=True,
    trace_attn_maps=False,
    loss_weights={"reciprocity": 1.0, "testifiability": 0.5, "decidability": 0.5},
    variant="A",
)

# Assumed to wrap MultiHeadAttentionWithConstraints and return (y, traces).
block = TransformerBlockWithConstraint(
    d_model=1024, n_heads_base=16, mlp_ratio=4, dropout=0.1, constraint=cfg
)

# x: [B, T, 1024]; attn_mask: additive (0 / -inf), broadcastable to [B, 1, T, T]
y, traces = block(x, attn_mask=None, need_weights=False)

# During training:
constraint_weight = 0.2  # λ: start small (0.1–0.3) and warm up
# Your usual CE; presumably applies the LM head to the hidden states y first.
lm_loss = language_model_loss(y, targets_next_tokens)
c_targets = {
    "reciprocity": recip_labels,  # [B, T], in {0, 1} or real-valued in [0, 1]
    "testifiability": testif_labels,  # [B, T]
    "decidability": decid_labels,  # [B, T]
}
c_loss = constraint_loss(traces, c_targets, cfg.loss_weights)
total = lm_loss + constraint_weight * c_loss

# Clear stale gradients first; without this they accumulate across steps.
optimizer.zero_grad(set_to_none=True)
total.backward()
optimizer.step()

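  • Near drop-in: this replaces your MHA sub-module, but forward returns a (y, traces) tuple. The calling block must therefore unpack it and pass need_weights/router_hint through if you use them. The rest of the stack is unchanged.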
  • Costs: Extra heads add projection + attention cost. Variant B caps this via smaller per-head dims. Profiling recommended on your target sequence lengths.
  • Data: You already have the NLI pipeline to emit token-wise labels/scores. If some constraints are sparse (few positive tokens), use focal BCE or reweight positives.
  • Eval: Track (i) canonical LM metrics, (ii) constraint F1/AUROC, and (iii) downstream adjudication tasks (the thing you actually care about). The gains should show up in (iii) even when (i) is flat.


Source date (UTC): 2025-08-25 20:38:50 UTC

Original post: https://x.com/i/articles/1960079443239837872

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *