ugld

UGLD: Uncertainty-Gated Lexical Decoding

Public API:

  • UGLD_Towards, UGLD_Against (HuggingFace LogitsProcessor)
  • UGLDTowardsConfig, UGLDAgainstConfig (configuration dataclasses)
 1"""
 2UGLD: Uncertainty-Gated Lexical Decoding
 3
 4Public API:
 5- UGLD_Towards, UGLD_Against (HuggingFace LogitsProcessor)
 6- UGLDTowardsConfig, UGLDAgainstConfig (configuration dataclasses)
 7"""
 8
 9from .ugld import (
10    UGLD_Towards,
11    UGLD_Against,
12    UGLDTowardsConfig,
13    UGLDAgainstConfig,
14)
15
16__all__ = [
17    "UGLD_Towards",
18    "UGLD_Against",
19    "UGLDTowardsConfig",
20    "UGLDAgainstConfig",
21]
22
23__version__ = "1.0.0"
class UGLD_Towards(transformers.generation.logits_process.LogitsProcessor):
class UGLD_Towards(LogitsProcessor):
    """Condition generation *towards* a predefined vocabulary (UGLD-t).

    At each decoding step the model's next-token distribution *p* is mixed
    with a conditioning prior *q* that concentrates probability mass on the
    *green* tokens.  The mixing strength is gated by the model's predictive
    uncertainty, measured via Shannon entropy, so that intervention is strong
    when the model is uncertain and negligible when it is confident.

    Formally, at each step:

    .. code-block:: text

        p  = SoftMax(z)               # current model distribution
        H  = -Σ p_i log p_i           # Shannon entropy
        φ  = σ((H - τ) / s)           # uncertainty gate ∈ (0, 1)
        α  = α_max · φ                # effective mixing coefficient
        p' = (1 − α) p + α q          # conditioned distribution

    The output is ``log(p')``; because ``SoftMax(log(p')) = p'``, this is a
    valid drop-in replacement for the raw logits expected by the HuggingFace
    generation pipeline.

    Args:
        config: A :class:`UGLDTowardsConfig` instance specifying the green
            vocabulary and all hyperparameters.

    Raises:
        ValueError: If ``config.alpha_max`` is outside ``[0, 1]`` or
            ``config.topk`` is not positive.

    Example::

        from transformers import LogitsProcessorList
        from ugld import UGLD_Towards, UGLDTowardsConfig

        processor = UGLD_Towards(UGLDTowardsConfig(
            green_token_ids=green_ids,
            alpha_max=0.5,
            tau=1.0,
            s=0.3,
            prior="renorm",
        ))
        out = model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))
    """

    def __init__(self, config: UGLDTowardsConfig):
        super().__init__()
        if not (0.0 <= config.alpha_max <= 1.0):
            raise ValueError("alpha_max must be in [0, 1].")
        if config.topk <= 0:
            raise ValueError("topk must be > 0.")
        self.cfg = config

        # Cache for uniform q (depends on vocab size/device/dtype).
        self._uniform_q: Optional[torch.Tensor] = None
        self._uniform_meta = None  # (V, device, dtype)

    def _uniform_prior(self, V: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Build (and cache) a uniform prior over the green vocabulary.

        Returns a 1-D tensor of shape ``[V]`` where each green token receives
        probability ``1 / |G|`` and all other tokens receive ``0``.  The
        result is cached and reused as long as *V*, *device*, and *dtype*
        remain unchanged.

        Args:
            V: Vocabulary size.
            device: Target device.
            dtype: Target floating-point dtype.

        Returns:
            Prior tensor of shape ``[V]``.
        """
        meta = (V, device, dtype)
        if self._uniform_q is not None and self._uniform_meta == meta:
            return self._uniform_q

        green = _valid_token_ids(self.cfg.green_token_ids, V, device)
        q = torch.zeros((V,), dtype=dtype, device=device)
        if green.numel() > 0:
            q[green] = 1.0 / float(green.numel())

        self._uniform_q = q
        self._uniform_meta = meta
        return q

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply UGLD-t to a batch of logits.

        Args:
            input_ids: Previously generated token ids, shape ``[B, T]``.
                Not used directly but required by the HuggingFace
                ``LogitsProcessor`` interface.
            scores: Raw logits produced by the model, shape ``[B, V]``.

        Returns:
            Modified log-probabilities of shape ``[B, V]``.  Applying
            ``SoftMax`` to the output yields the conditioned distribution
            ``p' = (1 − α) p + α q``.  If no valid green tokens exist the
            original *scores* are returned unchanged.
        """
        B, V = scores.shape
        device, dtype = scores.device, scores.dtype

        # Base distribution p over vocab.
        p = torch.softmax(scores, dim=-1)  # [B, V]

        # Entropy gate.
        H = _entropy_from_probs(p, self.cfg.eps)               # [B]
        phi = _gate_from_entropy(H, self.cfg.tau, self.cfg.s)  # [B] in [0,1]
        alpha = (self.cfg.alpha_max * phi).clamp(0.0, 1.0)     # [B]

        # Valid green token ids in this vocab.
        green = _valid_token_ids(self.cfg.green_token_ids, V, device)
        if green.numel() == 0:
            # No valid green tokens -> no-op.
            return scores

        # Build q: [B, V]
        if self.cfg.prior == "uniform":
            q = self._uniform_prior(V, device, dtype).unsqueeze(0).expand(B, V)

        else:
            q = torch.zeros_like(p)  # [B, V]
            mass_g = p.index_select(dim=-1, index=green)  # [B, |G|]

            if self.cfg.prior == "renorm":
                denom = mass_g.sum(dim=-1, keepdim=True).clamp_min(self.cfg.eps)
                q_g = mass_g / denom  # [B, |G|]
                # Scatter q_g back into vocab positions.
                q.scatter_(dim=-1, index=green.unsqueeze(0).expand(B, -1), src=q_g)

            elif self.cfg.prior == "topk":
                k = min(self.cfg.topk, green.numel())
                _, idx = torch.topk(mass_g, k=k, dim=-1)  # [B, k] indices into green
                chosen = green.index_select(0, idx.reshape(-1)).reshape(B, k)  # [B, k]
                # Uniform over the chosen top-k.
                q.scatter_(
                    dim=-1,
                    index=chosen,
                    src=torch.full((B, k), 1.0 / float(k), device=device, dtype=dtype),
                )

            else:
                raise ValueError(f"Unknown prior='{self.cfg.prior}'. Use: uniform|topk|renorm")

        # Mix in probability space (convex combination).
        p_prime = (1.0 - alpha.unsqueeze(-1)) * p + alpha.unsqueeze(-1) * q
        return (p_prime.clamp_min(self.cfg.eps)).log()

class UGLD_Against(transformers.generation.logits_process.LogitsProcessor):
class UGLD_Against(LogitsProcessor):
    """Condition generation *against* a predefined vocabulary (UGLD-a).

    At each decoding step a penalty is subtracted from the logits of *red*
    tokens.  The penalty strength is gated by the model's predictive
    uncertainty so that suppression is strong when the model is uncertain and
    negligible when it is confident.  Because the penalty is applied in logit
    space, the output remains unnormalised logits and can be passed directly
    to subsequent processors or sampling routines.

    Formally, at each step:

    .. code-block:: text

        p  = SoftMax(z)               # current model distribution
        H  = -Σ p_i log p_i           # Shannon entropy
        φ  = σ((H - τ) / s)           # uncertainty gate ∈ (0, 1)
        λ  = λ_max · φ                # effective penalty strength
        z' = z − λ r                  # penalised logits

    where *r* is a non-negative weight vector supported on the red tokens.

    Args:
        config: A :class:`UGLDAgainstConfig` instance specifying the red
            vocabulary and all hyperparameters.

    Raises:
        ValueError: If ``config.lambda_max < 0`` or ``config.fixed_r <= 0``.

    Example::

        from transformers import LogitsProcessorList
        from ugld import UGLD_Against, UGLDAgainstConfig

        processor = UGLD_Against(UGLDAgainstConfig(
            red_token_ids=red_ids,
            lambda_max=4.0,
            tau=1.0,
            s=0.3,
            weights="fixed",
        ))
        out = model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))
    """

    def __init__(self, config: UGLDAgainstConfig):
        super().__init__()
        if config.lambda_max < 0:
            raise ValueError("lambda_max must be >= 0.")
        if config.fixed_r <= 0:
            raise ValueError("fixed_r must be > 0 (weights must be positive on red tokens).")
        self.cfg = config

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply UGLD-a to a batch of logits.

        Args:
            input_ids: Previously generated token ids, shape ``[B, T]``.
                Not used directly but required by the HuggingFace
                ``LogitsProcessor`` interface.
            scores: Raw logits produced by the model, shape ``[B, V]``.

        Returns:
            Penalised logits of shape ``[B, V]``, equal to
            ``z' = z − λ r``.  If no valid red tokens exist, or if
            ``lambda_max`` is zero, the original *scores* are returned
            unchanged.
        """
        B, V = scores.shape
        device = scores.device

        red = _valid_token_ids(self.cfg.red_token_ids, V, device)
        if red.numel() == 0 or self.cfg.lambda_max == 0.0:
            return scores

        # Compute entropy gate from current model distribution p.
        p = torch.softmax(scores, dim=-1)  # [B, V]
        H = _entropy_from_probs(p, self.cfg.eps)               # [B]
        phi = _gate_from_entropy(H, self.cfg.tau, self.cfg.s)  # [B]
        lam = (self.cfg.lambda_max * phi).clamp_min(0.0)       # [B]

        # Build r: [B, V]
        r = torch.zeros_like(scores)

        if self.cfg.weights == "fixed":
            r[:, red] = self.cfg.fixed_r

        elif self.cfg.weights == "dynamic_minmax":
            pr = p.index_select(dim=-1, index=red)  # [B, |R|]
            pr_min = pr.min(dim=-1, keepdim=True).values
            pr_max = pr.max(dim=-1, keepdim=True).values
            f = (pr - pr_min) / (pr_max - pr_min + self.cfg.eps)  # [B, |R|] in [0,1]
            r_r = 1.0 + f  # [B, |R|] in [1,2]
            r.scatter_(dim=-1, index=red.unsqueeze(0).expand(B, -1), src=r_r)

        else:
            raise ValueError(f"Unknown weights='{self.cfg.weights}'. Use: fixed|dynamic_minmax")

        # Penalize in logit space: z' = z - λ r
        return scores - lam.unsqueeze(-1) * r
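The gated update is easy to trace by hand. The toy computation below (hypothetical numbers, plain PyTorch, no library code assumed) reproduces the `weights="fixed"` path of `z' = z − λ r` on a four-token vocabulary:

```python
import torch

z = torch.tensor([[2.0, 1.0, 0.0, -1.0]])  # toy logits, B = 1, V = 4
red = torch.tensor([1, 2])                  # red token ids
tau, s, lam_max, fixed_r = 1.0, 0.3, 4.0, 1.0

p = torch.softmax(z, dim=-1)                           # model distribution
H = -(p * p.clamp_min(1e-12).log()).sum(dim=-1)        # Shannon entropy
phi = torch.sigmoid((H - tau) / s)                     # uncertainty gate
lam = lam_max * phi                                    # effective penalty

r = torch.zeros_like(z)
r[:, red] = fixed_r                                    # "fixed" weights on red ids
z_prime = z - lam.unsqueeze(-1) * r                    # penalised logits
```

Non-red logits pass through untouched; each red logit drops by exactly `lam`, and `lam` shrinks towards zero as the distribution becomes more peaked.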

@dataclass(frozen=True)
class UGLDTowardsConfig:
@dataclass(frozen=True)
class UGLDTowardsConfig:
    """Configuration for :class:`UGLD_Towards`."""

    green_token_ids: Sequence[int]
    """Token ids that form the *green* vocabulary — the set of tokens the model
    is encouraged to generate.  Duplicates and out-of-range ids are ignored at
    runtime."""

    alpha_max: float = 0.25
    """Maximum mixing coefficient α ∈ [0, 1].  The effective α at each step is
    ``alpha_max * φ(p)``, so the actual intervention is always at most
    *alpha_max*."""

    tau: float = 3.0
    """Entropy threshold τ for the gate φ.  The gate is ~0.5 when the
    per-token entropy equals *tau*.  A good starting point is the median
    entropy over your dataset's decoding steps."""

    s: float = 0.3
    """Smoothing factor s > 0 for the gate sigmoid.  Smaller values make the
    gate switch more sharply."""

    eps: float = 1e-12
    """Small constant for numerical stability in log and division operations."""

    prior: Literal["uniform", "topk", "renorm"] = "renorm"
    """Which conditioning prior *q* to use:

    - `"uniform"` — uniform mass over all green tokens.
    - `"topk"` — uniform mass over the *topk* green tokens with the highest
      probability under the current model distribution.
    - `"renorm"` — renormalise the current model distribution restricted to
      green tokens (i.e. ``q_i ∝ p_i`` for i ∈ G).
    """

    topk: int = 16
    """Number of green candidates to keep when ``prior="topk"``.  Clamped to
    the number of valid green tokens at runtime."""
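The three `prior` choices can be compared on a toy distribution. The sketch below uses hypothetical numbers and plain PyTorch rather than the library's internal code:

```python
import torch

p = torch.tensor([0.5, 0.2, 0.2, 0.1])  # toy model distribution, V = 4
green = torch.tensor([1, 3])             # green vocabulary G

# "uniform": equal mass 1/|G| on every green token.
q_uniform = torch.zeros(4)
q_uniform[green] = 1.0 / green.numel()

# "renorm": model mass restricted to G, renormalised (q_i ∝ p_i for i ∈ G).
mass_g = p[green]
q_renorm = torch.zeros(4)
q_renorm[green] = mass_g / mass_g.sum()

# "topk" with k = 1: uniform mass over the most probable green token.
_, idx = torch.topk(mass_g, k=1)
q_topk = torch.zeros(4)
q_topk[green[idx]] = 1.0
```

Here `"renorm"` preserves the model's relative preferences within the green set, while `"uniform"` and `"topk"` flatten them; which behaves best depends on how well-calibrated the model's green-token probabilities are.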

@dataclass(frozen=True)
class UGLDAgainstConfig:
@dataclass(frozen=True)
class UGLDAgainstConfig:
    """Configuration for :class:`UGLD_Against`."""

    red_token_ids: Sequence[int]
    """Token ids that form the *red* vocabulary — the set of tokens the model
    is discouraged from generating.  Duplicates and out-of-range ids are
    ignored at runtime."""

    lambda_max: float = 4.0
    """Maximum logit penalty λ ≥ 0.  The effective penalty at each step is
    ``lambda_max * φ(p)``, so stronger penalties require higher
    *lambda_max*."""

    tau: float = 3.0
    """Entropy threshold τ for the gate φ.  See :class:`UGLDTowardsConfig`
    for guidance on choosing this value."""

    s: float = 0.3
    """Smoothing factor s > 0 for the gate sigmoid.  Smaller values make the
    gate switch more sharply."""

    eps: float = 1e-12
    """Small constant for numerical stability in log and division operations."""

    weights: Literal["fixed", "dynamic_minmax"] = "fixed"
    """How to assign per-token penalty weights within the red vocabulary:

    - `"fixed"` — every red token receives the same penalty weight *fixed_r*.
    - `"dynamic_minmax"` — penalty weights are proportional to the model's
      current probability for each red token, with min-max normalisation
      mapping the range to [1, 2].  Tokens the model is most likely to produce
      receive the heaviest penalty.
    """

    fixed_r: float = 1.0
    """Penalty weight applied to each red token when ``weights="fixed"``.
    Must be strictly positive."""
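A small worked example of the `dynamic_minmax` weighting described above, using hypothetical red-token probabilities and plain PyTorch:

```python
import torch

eps = 1e-12
p_red = torch.tensor([[0.40, 0.10, 0.25]])  # model probs on the red tokens, [B, |R|]

# Min-max normalise to [0, 1], then shift to [1, 2]:
pr_min = p_red.min(dim=-1, keepdim=True).values
pr_max = p_red.max(dim=-1, keepdim=True).values
f = (p_red - pr_min) / (pr_max - pr_min + eps)  # [B, |R|] in [0, 1]
r = 1.0 + f                                      # [B, |R|] in [1, 2]
```

The most probable red token (0.40) gets weight 2, the least probable (0.10) gets weight 1, and everything else is interpolated linearly, so the penalty budget is concentrated where the model is most tempted.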
