bitorch_engine.utils.model_helper.flatten_x
- bitorch_engine.utils.model_helper.flatten_x(x: Tensor)
Flattens a 3D tensor into a 2D tensor by combining the first two dimensions.
This is particularly useful for processing sequences in Transformer models such as BERT, where you may need to apply operations that expect 2D inputs.
- Parameters:
x (torch.Tensor) – A 3D tensor with shape [batch_size, seq_length, hidden_size].
- Returns:
A tuple containing the flattened 2D tensor with shape [batch_size * seq_length, hidden_size] and the original leading dimensions [batch_size, seq_length] as a list, for later unflattening.
- Return type:
tuple[torch.Tensor, list]
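The following is a minimal sketch of the documented behavior. It assumes a plain reshape over the first two dimensions. `flatten_x_sketch` and `unflatten_x_sketch` are hypothetical names and are not part of bitorch_engine, whose actual implementation may differ. The sketch also shows the typical round trip: flatten, apply a 2D operation such as `torch.nn.Linear`, then restore the sequence layout from the returned shape list.

```python
import torch


def flatten_x_sketch(x: torch.Tensor) -> tuple[torch.Tensor, list[int]]:
    """Hypothetical re-implementation of flatten_x, for illustration only."""
    if x.dim() != 3:
        raise ValueError(f"expected a 3D tensor, got {x.dim()}D")
    batch_size, seq_length, hidden_size = x.shape
    # Merge batch and sequence dims: row i * seq_length + j holds x[i, j].
    flat = x.reshape(batch_size * seq_length, hidden_size)
    return flat, [batch_size, seq_length]


def unflatten_x_sketch(flat: torch.Tensor, shape: list[int]) -> torch.Tensor:
    """Hypothetical inverse: restore [batch_size, seq_length, features].

    The last dimension is taken from `flat`, so this also works after an
    operation that changes the feature size (e.g. a linear projection).
    """
    return flat.reshape(shape[0], shape[1], flat.shape[-1])


# Usage: project hidden states with a layer that expects 2D input.
x = torch.randn(2, 5, 8)              # [batch_size, seq_length, hidden_size]
flat, shape = flatten_x_sketch(x)
print(flat.shape)                     # torch.Size([10, 8])
print(shape)                          # [2, 5]

linear = torch.nn.Linear(8, 16)
y = unflatten_x_sketch(linear(flat), shape)
print(y.shape)                        # torch.Size([2, 5, 16])
```

The sketch uses `reshape` rather than `view` so that it also accepts non-contiguous inputs (for example, a tensor produced by a transpose). `reshape` returns a view when it can and copies only when it must.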