bitorch_engine.utils.model_helper.flatten_x

bitorch_engine.utils.model_helper.flatten_x(x: Tensor)

Flattens a 3D tensor into a 2D tensor by combining the first two dimensions.

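This is particularly useful when processing sequences in models such as BERT and other Transformers, where some operations (for example, a 2D matrix multiplication) expect an input of shape [rows, features] rather than [batch_size, seq_length, hidden_size].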
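The flattening described above amounts to a single reshape that merges the batch and sequence dimensions. The following is a minimal sketch, assuming only the behavior stated in this docstring; `flatten_x_sketch` is a hypothetical name and is not the bitorch_engine source.

```python
import torch


def flatten_x_sketch(x: torch.Tensor) -> tuple[torch.Tensor, list]:
    """Hypothetical re-implementation of flatten_x for illustration only."""
    if x.dim() != 3:
        raise ValueError(f"expected a 3D tensor, got {x.dim()} dimensions")
    # Remember the leading dimensions [batch_size, seq_length] for unflattening
    shape = list(x.shape[:2])
    # Merge the first two dimensions: [B, S, H] -> [B * S, H]
    return x.reshape(-1, x.size(-1)), shape


x = torch.randn(2, 3, 4)          # [batch_size=2, seq_length=3, hidden_size=4]
flat, shape = flatten_x_sketch(x)  # flat: [6, 4], shape: [2, 3]
```

Because `reshape` keeps row-major order, row `b * seq_length + s` of the flattened tensor holds the hidden vector of token `s` in batch element `b`.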

Parameters:

x (torch.Tensor) – A 3D tensor with shape [batch_size, seq_length, hidden_size].

Returns:

A tuple containing the flattened 2D tensor with shape [batch_size * seq_length, hidden_size] and the original leading dimensions as a list [batch_size, seq_length], which can be used later to restore the 3D shape.

Return type:

tuple[torch.Tensor, list]
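The returned shape list is what makes the operation reversible. The sketch below shows a typical round trip under stated assumptions: the flattening step is written inline as the equivalent `reshape` (with bitorch_engine installed, `flatten_x(x)` would presumably take its place), and `unflatten_x_sketch` is a hypothetical helper, not a documented library function.

```python
import torch


def unflatten_x_sketch(y: torch.Tensor, shape: list) -> torch.Tensor:
    """Hypothetical inverse: [batch_size * seq_length, F] -> [batch_size, seq_length, F]."""
    # shape is the [batch_size, seq_length] list returned by the flattening step;
    # the feature size F is taken from y, so it may differ from hidden_size.
    return y.reshape(*shape, y.size(-1))


batch_size, seq_length, hidden_size, out_size = 2, 3, 4, 5
x = torch.randn(batch_size, seq_length, hidden_size)

# Assumed equivalent of `flat, shape = flatten_x(x)` per the docstring above
flat = x.reshape(-1, hidden_size)
shape = [batch_size, seq_length]

# torch.mm accepts only 2D tensors, so it needs the flattened input
w = torch.randn(hidden_size, out_size)
out_flat = torch.mm(flat, w)                # [6, 5]
out = unflatten_x_sketch(out_flat, shape)   # [2, 3, 5]
```

Reading the last dimension from the output tensor, rather than reusing hidden_size, lets the same helper restore shapes after projections that change the feature size.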