bitorch_engine.utils.model_helper.unflatten_x

bitorch_engine.utils.model_helper.unflatten_x(x: Tensor, shape: list)

Unflattens a 2D tensor back into a 3D tensor using the original shape, reversing the operation performed by flatten_x.

This function is useful for reconstructing the original 3D structure of sequence data after performing operations that require 2D input tensors.

Parameters:
  • x (torch.Tensor) – A 2D tensor with shape [batch_size * seq_length, output_size].

  • shape (list) – The original shape of the tensor before flattening, [batch_size, seq_length].

Returns:

The unflattened 3D tensor with shape [batch_size, seq_length, output_size].

Return type:

torch.Tensor
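The reshape described above can be sketched in a few lines. This is a minimal illustration, not the library's source. `unflatten_x_sketch` is a hypothetical name. The sketch assumes that `shape` starts with `[batch_size, seq_length]` and that `x` is the row-major flattening of a 3D tensor, such as the result of `t.reshape(-1, t.size(-1))`. The usage part flattens a 3D tensor, passes it through an `nn.Linear` layer that expects 2D input, and then restores the 3D layout.

```python
import torch


def unflatten_x_sketch(x: torch.Tensor, shape: list) -> torch.Tensor:
    """Hypothetical stand-in for unflatten_x: [B*S, H] -> [B, S, H]."""
    # Assumption: only the first two entries of `shape` are needed.
    batch_size, seq_length = shape[0], shape[1]
    if x.dim() != 2 or x.size(0) != batch_size * seq_length:
        raise ValueError(
            f"expected a 2D tensor with {batch_size * seq_length} rows, "
            f"got shape {tuple(x.shape)}"
        )
    # -1 lets PyTorch infer output_size from the last dimension of x.
    return x.reshape(batch_size, seq_length, -1)


# Usage: flatten, apply a layer that expects 2D input, then restore 3D.
hidden = torch.randn(2, 3, 4)               # [batch_size, seq_length, 4]
flat = hidden.reshape(-1, hidden.size(-1))  # [6, 4]
proj = torch.nn.Linear(4, 5)
out = proj(flat)                            # [6, 5]
restored = unflatten_x_sketch(out, list(hidden.shape[:2]))
print(restored.shape)  # torch.Size([2, 3, 5])
```

Because the flattening merges the batch and sequence dimensions in row-major order, reshaping with the same `[batch_size, seq_length]` prefix recovers the original layout, and the last dimension becomes whatever size the intermediate layer produced.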