Problem with LSTM from PyTorch

ASalc.1
Associate III

Hi,

I am trying to deploy an LSTM-based network created with PyTorch. The model is analysed correctly, but validation on desktop fails with the following error:

 

LOAD ERROR: exception: access violation reading 0x0000000000000004

 

On the MCU, a hard fault is raised at inference.

I noticed that the error occurs because I take the 'hidden_state' output of the first LSTM layer and feed it to the second one.

Any idea why this is happening?

This is a simple version of the model that reproduces the problem:

 

import torch
from torch import nn


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # hidden_state has shape (num_layers, batch, hidden_size) = (1, 1, 8) here
        out, (hidden_state, _) = self.lstm1(x)
        # Feeding hidden_state into the second LSTM is what triggers the error
        out, _ = self.lstm2(hidden_state)
        # Keep only the last time step for the linear layer
        out = self.linear(out[:, -1, :])
        return out


if __name__ == '__main__':
    model = LSTMModel(14, 8, 14)
    model.eval()

    # Input of shape (batch, sequence, features) = (1, 1, 14)
    dummy_input = torch.randn(1, 1, 14)
    _ = model(dummy_input)  # sanity check: the forward pass runs in PyTorch

    # Export the model to ONNX for deployment
    output_file = 'test-model.onnx'
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        verbose=False)
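
One thing that might be worth trying, as an untested sketch: for a single-layer, unidirectional nn.LSTM, the final hidden state holds the same values as the last time step of the output sequence, so the second LSTM could be fed from 'out' instead of 'hidden_state'. The names LSTMModelLastStep and check_equivalence below are hypothetical, and it is only an assumption that avoiding the hidden-state output changes what the deployment tool does with the exported graph. The two forward passes only match for batch size 1, because the original code passes the (1, batch, hidden) tensor to a batch_first LSTM.

```python
import torch
from torch import nn


class LSTMModelLastStep(nn.Module):
    """Hypothetical variant that never uses the hidden-state output of lstm1."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm1(x)
        # Last time step of the sequence output, kept as (batch, 1, hidden)
        last = out[:, -1:, :]
        out, _ = self.lstm2(last)
        return self.linear(out[:, -1, :])


def check_equivalence():
    """Check that h_n of lstm1 equals the last time step of its output."""
    torch.manual_seed(0)
    model = LSTMModelLastStep(14, 8, 14).eval()
    x = torch.randn(1, 5, 14)
    with torch.no_grad():
        out, (hidden_state, _) = model.lstm1(x)
    # hidden_state: (1, batch, hidden); out[:, -1, :]: (batch, hidden)
    return torch.allclose(hidden_state[0], out[:, -1, :], atol=1e-6)
```

Calling check_equivalence() should return True in plain PyTorch; the variant can then be exported with the same torch.onnx.export call as above to see whether the access violation still appears.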

 
