
Reporting strange behavior (error) when analyzing an LSTM model with stm32ai (batch size is not 1)

alpaca
Associate

Hello,

When I export torch.nn.LSTM to ONNX and use onnxsim to simplify it (and also to eliminate the ONNX Expand layer, which is not supported),

the generated .onnx model produces the following error when analyzed with stm32ai.exe (version 8.1.0):

NOT IMPLEMENTED: Batch size greater than 1 not implemented

I suspect that the Transpose layer just before the LSTM misleads the stm32ai analyzer into treating 7 as the batch size in the following ONNX graph, although 7 is actually the length of the input sequence.

[Attached image: screen capture of the ONNX graph, 2023-08-17 151142.jpg]
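
To make the suspicion above concrete, here is a minimal sketch of the shape bookkeeping only. It makes no ONNX or stm32ai calls. It assumes the exporter inserts a Transpose with permutation (1, 0, 2) in front of the LSTM, because the ONNX LSTM operator, in its default layout, expects its input as (sequence length, batch size, input size). The helper name transpose_shape is hypothetical and exists only for this illustration.

```python
# Illustrative shape bookkeeping only; nothing here calls ONNX or stm32ai.

def transpose_shape(shape, perm):
    """Return the shape an ONNX-style Transpose with `perm` would produce."""
    return tuple(shape[axis] for axis in perm)

# Shape of example_input in this post: (batch, sequence length, features).
batch_first_shape = (1, 7, 32)

# Assumed permutation of the Transpose node placed before the LSTM.
perm = (1, 0, 2)

lstm_input_shape = transpose_shape(batch_first_shape, perm)
print(lstm_input_shape)  # (7, 1, 32)
```

If the analyzer reads the first axis of the LSTM input as the batch dimension, it would see 7 there instead of 1, which would be consistent with the error message. This is only my guess about the tool's behavior, not something I have confirmed.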

I also wonder whether I am confused, or have missed the right way to convert the LSTM model to ONNX.

To reproduce the issue, I am including the code used to create the model and the LSTM ONNX file.

The onnxsim version is 0.4.33 and the PyTorch version is 2.0.0.

 

import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        lstm_output, (hidden, cell) = self.lstm(x)
        output = self.fc(hidden[0])  # Final hidden state of the first (and only) LSTM layer
        return output

# Define hyperparameters
vocab_size = 100  # Adjust as needed
embedding_dim = 32
hidden_dim = 64
output_dim = 10  # Adjust based on your classification task
num_layers = 1
data_len = 7

# Instantiate the model
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, output_dim, num_layers)

example_input = torch.zeros(1, data_len, embedding_dim)

# Export the model to ONNX
torch.onnx.export(model, example_input, 'lstm.onnx', verbose=False, do_constant_folding=True, 
                  export_params=True, opset_version=18, input_names = ['input'],   # the model's input names
                  output_names = ['output'])

 

Thanks in advance,

 
