bitorch_engine.utils.model_helper.unflatten_x
- bitorch_engine.utils.model_helper.unflatten_x(x: Tensor, shape: list)
Unflattens a 2D tensor back into a 3D tensor using the original shape, reversing the operation performed by flatten_x.
This function is useful for reconstructing the original 3D structure of sequence data after performing operations that require 2D input tensors.
- Parameters:
x (torch.Tensor) – A 2D tensor with shape [batch_size * seq_length, output_size].
shape (list) – The original shape of the tensor before flattening, given as a list [batch_size, seq_length].
- Returns:
The unflattened 3D tensor with shape [batch_size, seq_length, output_size].
- Return type:
torch.Tensor
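The round trip described above can be illustrated with a minimal sketch. It assumes that unflatten_x amounts to a reshape to [batch_size, seq_length, -1], with output_size inferred from the remaining elements. The helpers flatten_x_sketch and unflatten_x_sketch are hypothetical stand-ins written for illustration; they are not the bitorch_engine implementation, and the real flatten_x may return its results differently.

```python
from typing import List, Tuple

import torch


def flatten_x_sketch(x: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
    # Hypothetical stand-in for flatten_x: merge the batch and sequence
    # dimensions so a 2D-only operation can be applied per token.
    shape = list(x.shape[:2])  # [batch_size, seq_length]
    return x.reshape(-1, x.shape[-1]), shape


def unflatten_x_sketch(x: torch.Tensor, shape: List[int]) -> torch.Tensor:
    # Hypothetical stand-in for unflatten_x: restore the 3D layout.
    # -1 lets PyTorch infer output_size, which may differ from the
    # feature size before flattening (e.g. after a linear layer).
    batch_size, seq_length = shape
    return x.reshape(batch_size, seq_length, -1)


# Usage: apply a 2D-only operation to sequence data, then restore it.
x = torch.randn(2, 5, 8)                 # [batch_size, seq_length, features]
flat, shape = flatten_x_sketch(x)        # flat: [10, 8], shape: [2, 5]
out = torch.nn.Linear(8, 4)(flat)        # out: [10, 4]
restored = unflatten_x_sketch(out, shape)
print(restored.shape)                    # torch.Size([2, 5, 4])
```

Inferring the last dimension with -1, rather than passing it explicitly, is what lets the same call work whether or not the intermediate operation changed the feature size.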