"""UGLD: Uncertainty-Gated Lexical Decoding

Public API:
- UGLD_Towards, UGLD_Against (HuggingFace LogitsProcessor)
- UGLDTowardsConfig, UGLDAgainstConfig (configuration dataclasses)
"""

from .ugld import (
    UGLD_Towards,
    UGLD_Against,
    UGLDTowardsConfig,
    UGLDAgainstConfig,
)

__all__ = [
    "UGLD_Towards",
    "UGLD_Against",
    "UGLDTowardsConfig",
    "UGLDAgainstConfig",
]

__version__ = "1.0.0"
class UGLD_Towards(LogitsProcessor):
    """Condition generation *towards* a predefined vocabulary (UGLD-t).

    At each decoding step the model's next-token distribution *p* is mixed
    with a conditioning prior *q* that concentrates probability mass on the
    *green* tokens. The mixing strength is gated by the model's predictive
    uncertainty, measured via Shannon entropy, so that intervention is strong
    when the model is uncertain and negligible when it is confident.

    Formally, at each step:

    .. code-block:: text

        p  = SoftMax(z)          # current model distribution
        H  = -Σ p_i log p_i      # Shannon entropy
        φ  = σ((H - τ) / s)      # uncertainty gate ∈ (0, 1)
        α  = α_max · φ           # effective mixing coefficient
        p' = (1 − α) p + α q     # conditioned distribution

    The output is ``log(p')``; because ``SoftMax(log(p')) = p'``, this is a
    valid drop-in replacement for the raw logits expected by the HuggingFace
    generation pipeline.

    Args:
        config: A :class:`UGLDTowardsConfig` instance specifying the green
            vocabulary and all hyperparameters.

    Raises:
        ValueError: If ``config.alpha_max`` is outside ``[0, 1]`` or
            ``config.topk`` is not positive.

    Example::

        from transformers import LogitsProcessorList
        from ugld import UGLD_Towards, UGLDTowardsConfig

        processor = UGLD_Towards(UGLDTowardsConfig(
            green_token_ids=green_ids,
            alpha_max=0.5,
            tau=1.0,
            s=0.3,
            prior="renorm",
        ))
        out = model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))
    """

    def __init__(self, config: UGLDTowardsConfig):
        super().__init__()
        if not (0.0 <= config.alpha_max <= 1.0):
            raise ValueError("alpha_max must be in [0, 1].")
        if config.topk <= 0:
            raise ValueError("topk must be > 0.")
        self.cfg = config

        # Cache for uniform q (depends on vocab size/device/dtype).
        self._uniform_q: Optional[torch.Tensor] = None
        self._uniform_meta = None  # (V, device, dtype)

    def _uniform_prior(self, V: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Build (and cache) a uniform prior over the green vocabulary.

        Returns a 1-D tensor of shape ``[V]`` where each green token receives
        probability ``1 / |G|`` and all other tokens receive ``0``. The
        result is cached and reused as long as *V*, *device*, and *dtype*
        remain unchanged.

        Args:
            V: Vocabulary size.
            device: Target device.
            dtype: Target floating-point dtype.

        Returns:
            Prior tensor of shape ``[V]``.
        """
        meta = (V, device, dtype)
        if self._uniform_q is not None and self._uniform_meta == meta:
            return self._uniform_q

        green = _valid_token_ids(self.cfg.green_token_ids, V, device)
        q = torch.zeros((V,), dtype=dtype, device=device)
        if green.numel() > 0:
            q[green] = 1.0 / float(green.numel())

        self._uniform_q = q
        self._uniform_meta = meta
        return q

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply UGLD-t to a batch of logits.

        Args:
            input_ids: Previously generated token ids, shape ``[B, T]``.
                Not used directly but required by the HuggingFace
                ``LogitsProcessor`` interface.
            scores: Raw logits produced by the model, shape ``[B, V]``.

        Returns:
            Modified log-probabilities of shape ``[B, V]``. Applying
            ``SoftMax`` to the output yields the conditioned distribution
            ``p' = (1 − α) p + α q``. If no valid green tokens exist the
            original *scores* are returned unchanged.
        """
        B, V = scores.shape
        device, dtype = scores.device, scores.dtype

        # Base distribution p over vocab.
        p = torch.softmax(scores, dim=-1)  # [B, V]

        # Entropy gate.
        H = _entropy_from_probs(p, self.cfg.eps)               # [B]
        phi = _gate_from_entropy(H, self.cfg.tau, self.cfg.s)  # [B] in [0, 1]
        alpha = (self.cfg.alpha_max * phi).clamp(0.0, 1.0)     # [B]

        # Valid green token ids in this vocab.
        green = _valid_token_ids(self.cfg.green_token_ids, V, device)
        if green.numel() == 0:
            # No valid green tokens -> no-op.
            return scores

        # Build q: [B, V]
        if self.cfg.prior == "uniform":
            q = self._uniform_prior(V, device, dtype).unsqueeze(0).expand(B, V)
        else:
            q = torch.zeros_like(p)                       # [B, V]
            mass_g = p.index_select(dim=-1, index=green)  # [B, |G|]

            if self.cfg.prior == "renorm":
                denom = mass_g.sum(dim=-1, keepdim=True).clamp_min(self.cfg.eps)
                q_g = mass_g / denom  # [B, |G|]
                # Scatter q_g back into vocab positions.
                q.scatter_(dim=-1, index=green.unsqueeze(0).expand(B, -1), src=q_g)
            elif self.cfg.prior == "topk":
                k = min(self.cfg.topk, green.numel())
                _, idx = torch.topk(mass_g, k=k, dim=-1)  # [B, k] indices into green
                chosen = green.index_select(0, idx.reshape(-1)).reshape(B, k)  # [B, k]
                # Uniform mass over the chosen top-k green tokens.
                q.scatter_(
                    dim=-1,
                    index=chosen,
                    src=torch.full((B, k), 1.0 / float(k), device=device, dtype=dtype),
                )
            else:
                raise ValueError(f"Unknown prior='{self.cfg.prior}'. Use: uniform|topk|renorm")

        # Mix in probability space (convex combination).
        p_prime = (1.0 - alpha.unsqueeze(-1)) * p + alpha.unsqueeze(-1) * q
        return (p_prime.clamp_min(self.cfg.eps)).log()
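To make the gate arithmetic in the docstring easy to check by hand, here is a minimal pure-Python sketch of a single UGLD-t step with the ``renorm`` prior on a toy 4-token vocabulary. It is illustrative only, not part of the module: the function ``ugld_t_step`` and its toy inputs are hypothetical, and the ``eps`` clamping of the real implementation is omitted for brevity.

```python
import math


def ugld_t_step(logits, green, alpha_max=0.5, tau=1.0, s=0.3):
    """One UGLD-t step on a single toy logit vector (renorm prior).

    Hypothetical stand-alone sketch; the library operates on [B, V]
    torch tensors instead.
    """
    # p = SoftMax(z)
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    p = [e / total for e in exps]
    # H = -sum p_i log p_i (Shannon entropy)
    H = -sum(pi * math.log(pi) for pi in p if pi > 0.0)
    # phi = sigmoid((H - tau) / s); alpha = alpha_max * phi
    phi = 1.0 / (1.0 + math.exp(-(H - tau) / s))
    alpha = alpha_max * phi
    # q: current distribution renormalised over the green set
    mass = sum(p[i] for i in green)
    q = [p[i] / mass if i in green else 0.0 for i in range(len(p))]
    # p' = (1 - alpha) p + alpha q
    return [(1.0 - alpha) * pi + alpha * qi for pi, qi in zip(p, q)]


# Near-uniform logits -> high entropy -> strong pull towards green tokens.
p_prime = ugld_t_step([0.1, 0.0, -0.1, 0.05], green={0, 2})
```

With these near-uniform logits the entropy sits close to log 4, the gate opens, and the green tokens end up with well over half of the probability mass, while the output remains a valid distribution.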
class UGLD_Against(LogitsProcessor):
    """Condition generation *against* a predefined vocabulary (UGLD-a).

    At each decoding step a penalty is subtracted from the logits of *red*
    tokens. The penalty strength is gated by the model's predictive
    uncertainty so that suppression is strong when the model is uncertain and
    negligible when it is confident. Because the penalty is applied in logit
    space, the output remains unnormalised logits and can be passed directly
    to subsequent processors or sampling routines.

    Formally, at each step:

    .. code-block:: text

        p  = SoftMax(z)          # current model distribution
        H  = -Σ p_i log p_i      # Shannon entropy
        φ  = σ((H - τ) / s)      # uncertainty gate ∈ (0, 1)
        λ  = λ_max · φ           # effective penalty strength
        z' = z − λ r             # penalised logits

    where *r* is a non-negative weight vector supported on the red tokens.

    Args:
        config: A :class:`UGLDAgainstConfig` instance specifying the red
            vocabulary and all hyperparameters.

    Raises:
        ValueError: If ``config.lambda_max < 0`` or ``config.fixed_r <= 0``.

    Example::

        from transformers import LogitsProcessorList
        from ugld import UGLD_Against, UGLDAgainstConfig

        processor = UGLD_Against(UGLDAgainstConfig(
            red_token_ids=red_ids,
            lambda_max=4.0,
            tau=1.0,
            s=0.3,
            weights="fixed",
        ))
        out = model.generate(**inputs, logits_processor=LogitsProcessorList([processor]))
    """

    def __init__(self, config: UGLDAgainstConfig):
        super().__init__()
        if config.lambda_max < 0:
            raise ValueError("lambda_max must be >= 0.")
        if config.fixed_r <= 0:
            raise ValueError("fixed_r must be > 0 (weights must be positive on red tokens).")
        self.cfg = config

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply UGLD-a to a batch of logits.

        Args:
            input_ids: Previously generated token ids, shape ``[B, T]``.
                Not used directly but required by the HuggingFace
                ``LogitsProcessor`` interface.
            scores: Raw logits produced by the model, shape ``[B, V]``.

        Returns:
            Penalised logits of shape ``[B, V]``, equal to ``z' = z − λ r``.
            If no valid red tokens exist, or if ``lambda_max`` is zero, the
            original *scores* are returned unchanged.
        """
        B, V = scores.shape
        device = scores.device

        red = _valid_token_ids(self.cfg.red_token_ids, V, device)
        if red.numel() == 0 or self.cfg.lambda_max == 0.0:
            return scores

        # Compute entropy gate from current model distribution p.
        p = torch.softmax(scores, dim=-1)                      # [B, V]
        H = _entropy_from_probs(p, self.cfg.eps)               # [B]
        phi = _gate_from_entropy(H, self.cfg.tau, self.cfg.s)  # [B]
        lam = (self.cfg.lambda_max * phi).clamp_min(0.0)       # [B]

        # Build r: [B, V]
        r = torch.zeros_like(scores)

        if self.cfg.weights == "fixed":
            r[:, red] = self.cfg.fixed_r
        elif self.cfg.weights == "dynamic_minmax":
            pr = p.index_select(dim=-1, index=red)  # [B, |R|]
            pr_min = pr.min(dim=-1, keepdim=True).values
            pr_max = pr.max(dim=-1, keepdim=True).values
            f = (pr - pr_min) / (pr_max - pr_min + self.cfg.eps)  # [B, |R|] in [0, 1]
            r_r = 1.0 + f                                         # [B, |R|] in [1, 2]
            r.scatter_(dim=-1, index=red.unsqueeze(0).expand(B, -1), src=r_r)
        else:
            raise ValueError(f"Unknown weights='{self.cfg.weights}'. Use: fixed|dynamic_minmax")

        # Penalise in logit space: z' = z - λ r
        return scores - lam.unsqueeze(-1) * r
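The penalty step can likewise be checked by hand. Below is a minimal pure-Python sketch of one UGLD-a step with ``fixed`` weights on a toy 4-token vocabulary; the function ``ugld_a_step`` and its toy inputs are hypothetical and not part of the module.

```python
import math


def ugld_a_step(logits, red, lambda_max=4.0, tau=1.0, s=0.3, fixed_r=1.0):
    """One UGLD-a step on a single toy logit vector (fixed weights).

    Hypothetical stand-alone sketch; the library operates on [B, V]
    torch tensors instead.
    """
    # p = SoftMax(z), then Shannon entropy H.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    p = [e / total for e in exps]
    H = -sum(pi * math.log(pi) for pi in p if pi > 0.0)
    # Gate and effective penalty strength.
    phi = 1.0 / (1.0 + math.exp(-(H - tau) / s))
    lam = lambda_max * phi
    # z' = z - lam * r, with r = fixed_r on red tokens and 0 elsewhere.
    return [z - lam * fixed_r if i in red else z for i, z in enumerate(logits)]


# High-entropy step: red logits are pushed down by a sizable fraction
# of lambda_max, non-red logits are untouched.
z_prime = ugld_a_step([0.1, 0.0, -0.1, 0.05], red={1, 3})
```

Since the output stays in logit space, a chain of further processors (repetition penalties, top-k filtering, and so on) can consume ``z_prime`` unchanged.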
@dataclass(frozen=True)
class UGLDTowardsConfig:
    """Configuration for :class:`UGLD_Towards`."""

    green_token_ids: Sequence[int]
    """Token ids that form the *green* vocabulary — the set of tokens the model
    is encouraged to generate. Duplicates and out-of-range ids are ignored at
    runtime."""

    alpha_max: float = 0.25
    """Maximum mixing coefficient α ∈ [0, 1]. The effective α at each step is
    ``alpha_max * φ(p)``, so the actual intervention is always at most
    *alpha_max*."""

    tau: float = 3.0
    """Entropy threshold τ for the gate φ. The gate is ~0.5 when the
    per-token entropy equals *tau*. A good starting point is the median
    entropy over your dataset's decoding steps."""

    s: float = 0.3
    """Smoothing factor s > 0 for the gate sigmoid. Smaller values make the
    gate switch more sharply."""

    eps: float = 1e-12
    """Small constant for numerical stability in log and division operations."""

    prior: Literal["uniform", "topk", "renorm"] = "renorm"
    """Which conditioning prior *q* to use:

    - `"uniform"` — uniform mass over all green tokens.
    - `"topk"` — uniform mass over the *topk* green tokens with the highest
      probability under the current model distribution.
    - `"renorm"` — renormalise the current model distribution restricted to
      green tokens (i.e. ``q_i ∝ p_i`` for i ∈ G).
    """

    topk: int = 16
    """Number of green candidates to keep when ``prior="topk"``. Clamped to
    the number of valid green tokens at runtime."""
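The ``tau`` field's docstring suggests starting from the median per-step entropy of your data. A minimal sketch of that calibration, assuming you have collected a sample of next-token logit rows; the helper ``median_step_entropy`` and the two toy rows are hypothetical, and plain Python is used in place of torch for clarity:

```python
import math
import statistics


def median_step_entropy(logit_rows):
    """Median Shannon entropy over a sample of next-token logit rows.

    Hypothetical calibration helper: the result is a reasonable
    starting value for the `tau` field.
    """
    entropies = []
    for row in logit_rows:
        # Softmax each row, then H = -sum p_i log p_i.
        m = max(row)
        exps = [math.exp(z - m) for z in row]
        total = sum(exps)
        p = [e / total for e in exps]
        entropies.append(-sum(pi * math.log(pi) for pi in p if pi > 0.0))
    return statistics.median(entropies)


# One confident step and one maximally uncertain step over a toy
# 4-token vocabulary; tau lands between their entropies.
tau0 = median_step_entropy([[5.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
```

For a real vocabulary the entropies come from the model's logits over many decoding steps, and the resulting median is typically much larger than log 4.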
@dataclass(frozen=True)
class UGLDAgainstConfig:
    """Configuration for :class:`UGLD_Against`."""

    red_token_ids: Sequence[int]
    """Token ids that form the *red* vocabulary — the set of tokens the model
    is discouraged from generating. Duplicates and out-of-range ids are
    ignored at runtime."""

    lambda_max: float = 4.0
    """Maximum logit penalty λ ≥ 0. The effective penalty at each step is
    ``lambda_max * φ(p)``, so stronger penalties require higher
    *lambda_max*."""

    tau: float = 3.0
    """Entropy threshold τ for the gate φ. See :class:`UGLDTowardsConfig`
    for guidance on choosing this value."""

    s: float = 0.3
    """Smoothing factor s > 0 for the gate sigmoid. Smaller values make the
    gate switch more sharply."""

    eps: float = 1e-12
    """Small constant for numerical stability in log and division operations."""

    weights: Literal["fixed", "dynamic_minmax"] = "fixed"
    """How to assign per-token penalty weights within the red vocabulary:

    - `"fixed"` — every red token receives the same penalty weight *fixed_r*.
    - `"dynamic_minmax"` — penalty weights are proportional to the model's
      current probability for each red token, with min-max normalisation
      mapping the range to [1, 2]. Tokens the model is most likely to produce
      receive the heaviest penalty.
    """

    fixed_r: float = 1.0
    """Penalty weight applied to each red token when ``weights="fixed"``.
    Must be strictly positive."""
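The ``dynamic_minmax`` option is easiest to see on a toy example. Below is a minimal pure-Python sketch of the min-max weighting over the red tokens' current probabilities; the helper ``dynamic_minmax_weights`` and its toy input are hypothetical, not part of the module.

```python
def dynamic_minmax_weights(p_red, eps=1e-12):
    """Min-max penalty weights in [1, 2] over red-token probabilities.

    Hypothetical sketch of the `dynamic_minmax` option: the least
    probable red token gets weight 1, the most probable gets ~2.
    """
    lo, hi = min(p_red), max(p_red)
    return [1.0 + (pi - lo) / (hi - lo + eps) for pi in p_red]


# Toy probabilities for three red tokens under the current distribution.
w = dynamic_minmax_weights([0.05, 0.20, 0.10])
```

The most probable red token receives the heaviest weight, so the penalty is concentrated where suppression actually matters; with ``weights="fixed"`` all three would receive ``fixed_r`` instead.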