updown.modules.cbs
class updown.modules.cbs.ConstrainedBeamSearch(end_index: int, max_steps: int = 20, beam_size: int = 5, per_node_beam_size: Optional[int] = None)
Bases: object
Implements Constrained Beam Search for decoding the most likely sequences conditioned on a Finite State Machine with specified state transitions.
Note
We keep the method signatures as close to BeamSearch as possible. Most of the docstring is adapted from AllenNLP, so thanks to them!
- Parameters
- end_index : int
The index of the @@BOUNDARY@@ token in the target vocabulary.
- max_steps : int, optional (default = 20)
The maximum number of decoding steps to take, i.e. the maximum length of the predicted sequences.
- beam_size : int, optional (default = 5)
The width of the beam used for each “main state” in the Finite State Machine.
- per_node_beam_size : int, optional (default = beam_size)
The maximum number of candidates to consider per node at each step in the search. If not given, this defaults to beam_size. Setting this parameter to a number smaller than beam_size may give better results, as it can introduce more diversity into the search. See Beam Search Strategies for Neural Machine Translation, Freitag and Al-Onaizan, 2017.
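To illustrate how per_node_beam_size and beam_size interact, here is a minimal sketch of one expansion step of plain (unconstrained) beam search for a single batch item. It assumes the AllenNLP-style scheme in which each node proposes only its top per_node_beam_size tokens and the best beam_size candidates overall are kept. All variable names are illustrative, and this is not the code of ConstrainedBeamSearch itself.

```python
import torch

# Illustrative sizes (assumptions, not values taken from the library).
beam_size, per_node_beam_size, vocab_size = 3, 2, 6

torch.manual_seed(0)
# Cumulative log probabilities of the beams kept so far: (beam_size,)
beam_log_probs = torch.log_softmax(torch.randn(beam_size), dim=-1)
# Next-token log probabilities for each beam: (beam_size, vocab_size)
step_log_probs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1)

# Each node proposes only its top per_node_beam_size continuations.
top_log_probs, top_tokens = step_log_probs.topk(per_node_beam_size, dim=-1)

# Score every proposed continuation: (beam_size * per_node_beam_size,)
candidates = (beam_log_probs.unsqueeze(-1) + top_log_probs).reshape(-1)

# Keep the best beam_size candidates across all nodes.
best_log_probs, best_idx = candidates.topk(beam_size)
backpointers = best_idx // per_node_beam_size   # which beam each survivor extends
next_tokens = top_tokens.reshape(-1)[best_idx]  # token appended to that beam
```

With per_node_beam_size smaller than beam_size, no single node can fill the whole beam, which is where the extra diversity comes from.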
search(self, start_predictions: torch.Tensor, start_state: Dict[str, torch.Tensor], step: Callable[[torch.Tensor, Dict[str, torch.Tensor]], Tuple[torch.Tensor, Dict[str, torch.Tensor]]], fsm: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Given a starting state, a step function, and an FSM adjacency matrix, apply Constrained Beam Search to find the most likely target sequences that satisfy the constraints specified in the FSM.
Note
If your step function returns -inf for some log probabilities (for example, if you’re using a masked log-softmax), then some of the “best” sequences returned may also have -inf log probability. Specifically, this happens when the beam size is larger than the number of actions with finite log probability (non-zero probability) returned by the step function. Therefore, if you’re using a mask, you may want to check the results from search and potentially discard sequences with non-finite log probability.
- Parameters
- start_predictions : torch.Tensor
A tensor containing the initial predictions with shape (batch_size, ). These are usually just @@BOUNDARY@@ token indices.
- start_state : Dict[str, torch.Tensor]
The initial state passed to the step function. Each value of the state dict should be a tensor of shape (batch_size, *), where * means any other number of dimensions.
- step : StepFunctionType
A function that is responsible for computing the next most likely tokens, given the current state and the predictions from the last time step. The function should accept two arguments: the first is a tensor of shape (group_size,) representing the indices of the predicted tokens from the last time step, and the second is the current state. The group_size will be batch_size * beam_size * num_fsm_states, except in the initial step, for which it will just be batch_size. The function is expected to return a tuple, where the first element is a tensor of shape (group_size, vocab_size) containing the log probabilities of the tokens for the next step, and the second element is the updated state. Each tensor in the state should have shape (group_size, *), where * means any other number of dimensions.
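The step contract described above can be sketched as follows. This is a minimal, hypothetical step function built on a toy GRU decoder. The names VOCAB_SIZE, HIDDEN_SIZE, embedding, cell, output_layer and the "hidden" state key are assumptions made for illustration and are not part of updown.

```python
from typing import Dict, Tuple

import torch

VOCAB_SIZE = 8   # hypothetical vocabulary size
HIDDEN_SIZE = 4  # hypothetical decoder hidden size

torch.manual_seed(0)
embedding = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
cell = torch.nn.GRUCell(HIDDEN_SIZE, HIDDEN_SIZE)
output_layer = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE)


def step(
    last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # last_predictions: (group_size,), state["hidden"]: (group_size, HIDDEN_SIZE)
    hidden = cell(embedding(last_predictions), state["hidden"])
    # Log probabilities over the vocabulary: (group_size, VOCAB_SIZE)
    log_probs = torch.log_softmax(output_layer(hidden), dim=-1)
    # Every tensor in the returned state keeps group_size as its first dimension.
    return log_probs, {"hidden": hidden}


# Initial step: group_size is just batch_size.
batch_size = 2
start_predictions = torch.zeros(batch_size, dtype=torch.long)
start_state = {"hidden": torch.zeros(batch_size, HIDDEN_SIZE)}
log_probs, new_state = step(start_predictions, start_state)
```

Because the function only relies on the leading dimension, the same code works unchanged when group_size grows to batch_size * beam_size * num_fsm_states in later steps.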
- Returns
- Tuple[torch.Tensor, torch.Tensor]
Tuple of (predictions, log_probabilities), where predictions has shape (batch_size, num_fsm_states, beam_size, max_steps) and log_probabilities has shape (batch_size, num_fsm_states, beam_size).
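As a hedged illustration of working with these shapes, the sketch below uses stand-in tensors (not real output of search) to pick the highest-scoring beam for every FSM state and to flag sequences with non-finite log probability, as the note on masked log-softmax suggests. The variable names and the sample values are assumptions made for the example.

```python
import torch

batch_size, num_fsm_states, beam_size, max_steps = 2, 3, 2, 4

torch.manual_seed(0)
# Stand-ins with the documented shapes; real values would come from search().
predictions = torch.randint(
    0, 10, (batch_size, num_fsm_states, beam_size, max_steps)
)
neg_inf = float("-inf")
log_probabilities = torch.tensor(
    [
        [[-1.0, -2.0], [-0.5, neg_inf], [neg_inf, neg_inf]],
        [[-3.0, -1.5], [-2.5, -0.7], [-4.0, neg_inf]],
    ]
)

# True where a beam has a usable (finite) score: (batch, states, beams)
finite_mask = torch.isfinite(log_probabilities)
# FSM states that have at least one finite-scoring beam: (batch, states)
state_has_valid = finite_mask.any(dim=-1)

# Index of the best beam per (batch, state): (batch, states)
best_beam = log_probabilities.argmax(dim=-1)
# Gather the corresponding token sequences: (batch, states, max_steps)
index = best_beam[..., None, None].expand(-1, -1, 1, max_steps)
best_predictions = predictions.gather(2, index).squeeze(2)
```

Entries of best_predictions are only meaningful where state_has_valid is True; the remaining ones should be discarded rather than decoded.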