updown.modules.cbs

class updown.modules.cbs.ConstrainedBeamSearch(end_index: int, max_steps: int = 20, beam_size: int = 5, per_node_beam_size: Optional[int] = None)[source]

Bases: object

Implements Constrained Beam Search for decoding the most likely sequences conditioned on a Finite State Machine with specified state transitions.

Note

We keep the method signatures as close to BeamSearch as possible. Most of the docstring is adapted from AllenNLP, so thanks to them!

Parameters
end_index : int

The index of the @@BOUNDARY@@ token in the target vocabulary.

max_steps : int, optional (default = 20)

The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.

beam_size : int, optional (default = 5)

The width of the beam used for each “main state” in the Finite State Machine.

per_node_beam_size : int, optional (default = beam_size)

The maximum number of candidates to consider per node, at each step in the search. If not given, this just defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation. Freitag and Al-Onaizan, 2017.
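
For illustration, a minimal construction sketch. The @@BOUNDARY@@ index below is hypothetical and depends on your vocabulary:

    from updown.modules.cbs import ConstrainedBeamSearch

    # Hypothetical index of the @@BOUNDARY@@ token in the target vocabulary.
    BOUNDARY_INDEX = 2

    beam_search = ConstrainedBeamSearch(
        end_index=BOUNDARY_INDEX,
        max_steps=20,
        beam_size=5,
        per_node_beam_size=2,  # smaller than beam_size can add diversity
    )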

search(self, start_predictions: torch.Tensor, start_state: Dict[str, torch.Tensor], step: Callable[[torch.Tensor, Dict[str, torch.Tensor]], Tuple[torch.Tensor, Dict[str, torch.Tensor]]], fsm: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Given a starting state, a step function, and an FSM adjacency matrix, apply Constrained Beam Search to find the most likely target sequences that satisfy the constraints specified by the FSM.

Note

If your step function returns -inf for some log probabilities (like if you’re using a masked log-softmax), then some of the “best” sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function. Therefore, if you’re using a mask, you may want to check the results from search and discard any sequences with non-finite log probability.
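
A short sketch of that check, where log_probabilities is the second tensor returned by search (see the usage sketch at the end of this page):

    import torch

    # log_probabilities: (batch_size, num_fsm_states, beam_size)
    finite_mask = torch.isfinite(log_probabilities)

    # Number of usable sequences per (batch instance, FSM state); beam
    # entries with -inf log probability should be discarded downstream.
    num_valid = finite_mask.sum(dim=-1)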

Parameters
start_predictions : torch.Tensor

A tensor containing the initial predictions with shape (batch_size,). These are usually just @@BOUNDARY@@ token indices.

start_state : Dict[str, torch.Tensor]

The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.

step : StepFunctionType

A function responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two arguments: the first is a tensor of shape (group_size,) containing the indices of the tokens predicted at the last time step, and the second is the current state. The group_size will be batch_size * beam_size * num_fsm_states, except in the initial step, for which it will just be batch_size. The function is expected to return a tuple, where the first element is a tensor of shape (group_size, vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. Each tensor in the state dict should have shape (group_size, *), where * means any other number of dimensions. A hedged sketch of a conforming step function follows.
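
As a concrete illustration of this contract, here is a sketch of a step function backed by a single LSTM cell. The decoder components and sizes are assumptions for the example, not part of this library:

    from typing import Dict, Tuple

    import torch
    import torch.nn.functional as F

    vocab_size, embedding_size, hidden_size = 10000, 300, 512  # hypothetical sizes

    # Hypothetical decoder components; any model exposing this interface works.
    embedding = torch.nn.Embedding(vocab_size, embedding_size)
    lstm_cell = torch.nn.LSTMCell(embedding_size, hidden_size)
    output_projection = torch.nn.Linear(hidden_size, vocab_size)

    def step(
        last_predictions: torch.Tensor,   # (group_size,)
        state: Dict[str, torch.Tensor],   # each value: (group_size, *)
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # Embed the tokens predicted at the previous time step.
        token_embeddings = embedding(last_predictions)  # (group_size, embedding_size)
        # Advance the recurrent state by one step.
        h, c = lstm_cell(token_embeddings, (state["h"], state["c"]))
        # Log probabilities over the vocabulary for the next step.
        class_log_probs = F.log_softmax(output_projection(h), dim=-1)
        return class_log_probs, {"h": h, "c": c}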

Returns
Tuple[torch.Tensor, torch.Tensor]

Tuple of (predictions, log_probabilities), where predictions has shape (batch_size, num_fsm_states, beam_size, max_steps) and log_probabilities has shape (batch_size, num_fsm_states, beam_size).
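
Putting the pieces together, a hedged end-to-end sketch, reusing the hypothetical names defined in the sketches above. The shape of the fsm tensor is not documented on this page; the sketch assumes the convention (batch_size, num_fsm_states, num_fsm_states, vocab_size), where entry [i, p, q, v] is nonzero iff emitting token v moves instance i from state p to state q:

    import torch

    batch_size, num_fsm_states = 4, 3  # hypothetical sizes

    # Every instance starts decoding from the @@BOUNDARY@@ token.
    start_predictions = torch.full((batch_size,), BOUNDARY_INDEX, dtype=torch.long)
    start_state = {
        "h": torch.zeros(batch_size, hidden_size),
        "c": torch.zeros(batch_size, hidden_size),
    }

    # Trivial FSM with only self-loops; a real FSM marks constraint-word
    # transitions between states. (Assumed shape, see the note above.)
    fsm = torch.zeros(
        batch_size, num_fsm_states, num_fsm_states, vocab_size, dtype=torch.long
    )
    for s in range(num_fsm_states):
        fsm[:, s, s, :] = 1

    predictions, log_probabilities = beam_search.search(
        start_predictions, start_state, step, fsm
    )
    # predictions:       (batch_size, num_fsm_states, beam_size, max_steps)
    # log_probabilities: (batch_size, num_fsm_states, beam_size)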