updown.utils.decoding
updown.utils.decoding.select_best_beam(beams: torch.Tensor, beam_log_probabilities: torch.Tensor) → torch.Tensor

Select the best beam out of a set of decoded beams.
- Parameters
  - beams: torch.Tensor
    A tensor of shape (batch_size, num_states, max_decoding_steps) containing beams decoded by BeamSearch. These beams are sorted according to their likelihood (descending) in the beam_size dimension.
  - beam_log_probabilities: torch.Tensor
    A tensor of shape (batch_size, num_states, beam_size) containing the likelihood of decoded beams.
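Since the beams arrive sorted in descending order of likelihood, picking the best one should amount to taking the first entry along the sorted dimension. The sketch below assumes that the second dimension of beams is that sorted beam dimension. select_best_beam_sketch is a hypothetical name used only for illustration, not the library's function, and it does not need beam_log_probabilities because the ordering already reflects them.

```python
import torch


def select_best_beam_sketch(beams: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: return the top-ranked beam for every instance.

    Assumption: beams has shape (batch_size, num_beams, max_decoding_steps)
    and dimension 1 is already sorted by likelihood, most likely first.
    """
    # Index 0 along the sorted dimension is the most likely beam.
    return beams[:, 0, :]


# Two instances, three beams each, four decoding steps.
beams = torch.arange(24).view(2, 3, 4)
best = select_best_beam_sketch(beams)
print(best.shape)     # torch.Size([2, 4])
print(best.tolist())  # [[0, 1, 2, 3], [12, 13, 14, 15]]
```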
updown.utils.decoding.select_best_beam_with_constraints(beams: torch.Tensor, beam_log_probabilities: torch.Tensor, given_constraints: torch.Tensor, min_constraints_to_satisfy: int = 2) → torch.Tensor

Select the best beam that satisfies a specified minimum number of constraints, out of the total number of constraints given.
Note
The implementation of this function goes hand in hand with the FSM-building implementation in build(), which defines which states satisfy which (basically, how many) constraints. If the “definition” of states changes, the selection of beams must change accordingly.
- Parameters
  - beams: torch.Tensor
    A tensor of shape (batch_size, num_states, beam_size, max_decoding_steps) containing beams decoded by ConstrainedBeamSearch. These beams are sorted according to their likelihood (descending) in the beam_size dimension.
  - beam_log_probabilities: torch.Tensor
    A tensor of shape (batch_size, num_states, beam_size) containing the likelihood of decoded beams.
  - given_constraints: torch.Tensor
    A tensor of shape (batch_size,) containing the number of constraints given at the start of decoding.
  - min_constraints_to_satisfy: int, optional (default = 2)
    Minimum number of constraints to satisfy. This is either 2, or given_constraints if fewer than 2 constraints are given. Beams corresponding to states that do not satisfy at least this many constraints are dropped. Only values up to 3 are supported.
- Returns
- torch.Tensor
Decoded sequence (beam) with the highest likelihood among beams satisfying the constraints.
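A selection routine consistent with the description above could, for each instance, keep only the states that satisfy enough constraints, take the top beam of each such state (index 0, since beams are sorted within a state), and return the candidate with the highest log-probability. This is a sketch under stated assumptions, not the library's implementation: it assumes states are bitmasks of satisfied constraints (bit k set once constraint k is met) and that num_states is at least 2 ** given_constraints for every instance. select_best_beam_with_constraints_sketch is a hypothetical name.

```python
from typing import List

import torch


def select_best_beam_with_constraints_sketch(
    beams: torch.Tensor,                   # (batch_size, num_states, beam_size, max_decoding_steps)
    beam_log_probabilities: torch.Tensor,  # (batch_size, num_states, beam_size)
    given_constraints: torch.Tensor,       # (batch_size,)
    min_constraints_to_satisfy: int = 2,
) -> torch.Tensor:
    """Hypothetical sketch assuming a bitmask FSM state encoding."""
    if min_constraints_to_satisfy > 3:
        raise NotImplementedError("Only up to 3 constraints to satisfy are supported.")

    best_beams: List[torch.Tensor] = []
    for i in range(beams.size(0)):
        num_given = int(given_constraints[i].item())
        threshold = min(num_given, min_constraints_to_satisfy)
        # Assumption: a state's satisfied-constraint count is its popcount.
        valid = [s for s in range(2 ** num_given) if bin(s).count("1") >= threshold]
        # Beams are sorted within each state, so index 0 is that state's best beam.
        candidate_beams = beams[i, valid, 0, :]                # (len(valid), max_decoding_steps)
        candidate_scores = beam_log_probabilities[i, valid, 0]  # (len(valid),)
        best_beams.append(candidate_beams[torch.argmax(candidate_scores)])
    # Shape: (batch_size, max_decoding_steps)
    return torch.stack(best_beams)


# One instance, 4 states (2 constraints), 2 beams per state, 3 decoding steps.
beams = torch.arange(24).view(1, 4, 2, 3)
log_probs = torch.tensor([[[-0.1, -0.5], [-2.0, -3.0], [-1.5, -2.5], [-4.0, -4.5]]])
given = torch.tensor([2])
print(select_best_beam_with_constraints_sketch(beams, log_probs, given).tolist())
# [[18, 19, 20]]
print(select_best_beam_with_constraints_sketch(beams, log_probs, given, 1).tolist())
# [[12, 13, 14]]
```

In the first call, state 0 has the best score but satisfies no constraints, so only state 3 (both constraints met) remains. Lowering the minimum to 1 admits states 1 to 3, and state 2 wins on its top-beam score.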