2025-03-22 9:53 PM
I have attached my ONNX model, which I have already tested. The conversion fails with the following error:
INTERNAL ERROR: H not found in shape with shape map (BATCH, CH, W)
2025-03-24 8:51 AM
Hello @Nephalem,
Something is probably going wrong during the conversion of the model because of the original shape of your input.
I'll take a look and update you once I know more.
Have a good day,
Julian
2025-03-26 3:28 AM
Hello @Nephalem,
Your issue comes from an Einsum layer.
In your case, the operation that does not work is einsum("bhqk,bkhd->bqhd").
Einsum does not seem to be fully supported currently, so you need to replace the Einsum layers (your model has several of them) with simple matrix operations instead. For example:
Equivalent operations:
Instead of einsum("bhqk,bkhd->bqhd", A, B), use:
import torch

# Example sizes (arbitrary values for illustration)
batch, heads, query, key, dim = 2, 4, 8, 8, 16

# Example tensors
A = torch.randn(batch, heads, query, key)  # (b, h, q, k)
B = torch.randn(batch, key, heads, dim)    # (b, k, h, d)

# Permute B to (b, h, k, d) so that k is the contraction axis for matmul
B_transposed = B.permute(0, 2, 1, 3)  # (b, h, k, d)

# Batched matrix multiplication over the leading (b, h) axes
result = torch.matmul(A, B_transposed)  # (b, h, q, d)

# Permute to the expected output shape (b, q, h, d)
result = result.permute(0, 2, 1, 3)  # (b, q, h, d)
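Before re-exporting, it may be worth confirming numerically that the replacement matches the original einsum. The sketch below wraps the permute/matmul/permute sequence in a helper (the name `bhqk_bkhd_to_bqhd` is hypothetical) and compares it with torch.einsum on random tensors of arbitrary sizes.

```python
import torch


def bhqk_bkhd_to_bqhd(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Einsum-free equivalent of torch.einsum("bhqk,bkhd->bqhd", A, B).

    A: (b, h, q, k), B: (b, k, h, d) -> result: (b, q, h, d)
    """
    # (b, k, h, d) -> (b, h, k, d): k becomes the second-to-last axis of B
    B_t = B.permute(0, 2, 1, 3)
    # Batched matmul over (b, h): (b, h, q, k) @ (b, h, k, d) -> (b, h, q, d)
    out = torch.matmul(A, B_t)
    # (b, h, q, d) -> (b, q, h, d)
    return out.permute(0, 2, 1, 3)


torch.manual_seed(0)
A = torch.randn(2, 4, 5, 6)  # (b, h, q, k)
B = torch.randn(2, 6, 4, 3)  # (b, k, h, d)

reference = torch.einsum("bhqk,bkhd->bqhd", A, B)
candidate = bhqk_bkhd_to_bqhd(A, B)

print(candidate.shape)  # torch.Size([2, 5, 4, 3])
print(torch.allclose(reference, candidate, atol=1e-5))
```

When exported to ONNX, this pattern should appear as Transpose and MatMul nodes rather than Einsum. After re-exporting, `[n.name for n in onnx.load("model.onnx").graph.node if n.op_type == "Einsum"]` (with `model.onnx` standing for your file) should return an empty list once every Einsum layer has been replaced.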
Have a good day,
Julian