cancel
Showing results for 
Search instead for 
Did you mean: 

MNIST implementation issue on STM32H743

yiulsup
Associate

1. This is my Python script to build the MNIST model weights:

import tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np

# Train a small CNN on MNIST and export it as a float32 .tflite file for
# import into X-CUBE-AI / ST Edge AI.

# MNIST digits: 28x28 grayscale images with integer labels 0-9.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] and store them as float32, the same precision the
# generated C network uses for its input buffer. The firmware must apply the
# identical transform (pixel / 255.0f); otherwise on-device predictions will
# not match the ones obtained here.
x_train = (x_train / 255.0).astype(np.float32)
x_test = (x_test / 255.0).astype(np.float32)

# Add the trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

model = models.Sequential([
    layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.25),
    # 10 softmax scores, one per digit; the firmware takes the arg-max.
    layers.Dense(10, activation='softmax'),
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=50, batch_size=64, validation_split=0.1)

# Held-out accuracy gives a reference to compare the on-device results with.
_, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"test accuracy: {test_acc:.4f}")

# NOTE: with Keras 3 (TensorFlow >= 2.16) model.save() only accepts
# ".keras"/".h5" paths; there, use model.export("mnist_cnn_model") to produce
# the SavedModel directory read by the converter below.
model.save("mnist_cnn_model")

# No converter.optimizations are set, so the model stays float32 (no
# quantization): the C network then takes float inputs and yields float scores.
converter = tf.lite.TFLiteConverter.from_saved_model("mnist_cnn_model")
tflite_model = converter.convert()

with open("mnist_cnn_model.tflite", "wb") as f:
    f.write(tflite_model)

 

2. I load it into X-CUBE-AI and get the files below:
app_x-cube-ai.c network_config.h network_data_params.c network.h
app_x-cube-ai.h network_data.c network_data_params.h
network.c network_data.h network_generate_report.txt

 

3. In app_x-cube-ai.c, my code is as follows:

 
const uint8_t mnist_digit_4[28*28] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 232, 253, 253, 95, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 86, 46, 0, 0, 0, 0, 0, 0, 91, 246, 252, 232, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 103, 252, 187, 13, 0, 0, 0, 0, 22, 219, 252, 252, 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 8, 181, 252, 246, 30, 0, 0, 0, 0, 65, 252, 237, 197, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 87, 0, 0, 0, 13, 172, 252, 252, 104, 0, 0, 0, 0, 5, 184, 252, 67, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 172, 252, 248, 145, 14, 0, 0, 0, 0, 109, 252, 183, 137, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 224, 252, 248, 134, 0, 0, 0, 0, 0, 53, 238, 252, 245, 86, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 174, 252, 223, 88, 0, 0, 0, 0, 0, 0, 209, 252, 252, 179, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 171, 252, 246, 61, 0, 0, 0, 0, 0, 0, 83, 241, 252, 211, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 252, 252, 249, 220, 220, 215, 111, 192, 220, 221, 243, 252, 252, 149, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 144, 253, 253, 253, 253, 253, 253, 253, 253, 253, 255, 253, 226, 153, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 44, 77, 77, 77, 77, 77, 77, 77, 77, 153, 253, 235, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 74, 214, 240, 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 221, 243, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 180, 252, 119, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 252, 153, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 136, 251, 226, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123, 252, 246, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 165, 252, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 165, 175, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

const uint8_t mnist_digit_5[28*28] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 18, 46, 136, 136, 244, 255, 241, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 94, 163, 253, 253, 253, 253, 238, 218, 204, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 131, 253, 253, 253, 253, 237, 200, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 155, 246, 253, 247, 108, 65, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 207, 253, 253, 230, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 157, 253, 253, 125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 89, 253, 250, 57, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 89, 253, 247, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 89, 253, 247, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 89, 253, 247, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 231, 249, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 225, 253, 231, 213, 213, 123, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 172, 253, 253, 253, 253, 253, 190, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 116, 72, 124, 209, 253, 253, 141, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 219, 253, 206, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 104, 246, 253, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 213, 253, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 226, 253, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 132, 253, 209, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 253, 86, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };



int acquire_and_process_data(float *input_data[]) {
  /* Copy the embedded test image (mnist_digit_4) into the first input
   * buffer, converting each 8-bit pixel to a float in [0, 1]. This mirrors
   * the x / 255 normalisation applied by the training script.
   * Always reports success (0). */
  float *dst = input_data[0];
  const uint8_t *src = mnist_digit_4;
  const int n_pixels = 28 * 28;

  for (int p = 0; p < n_pixels; p++) {
    dst[p] = (float)src[p] / 255.0f;
  }

  return 0;
}



int post_process(float *output_data) {
  /* Return the index (0..9) of the largest of the 10 scores in output_data.
   * Only a strictly greater score replaces the current best, so ties
   * resolve to the lowest index. */
  int best = 0;

  for (int k = 1; k < 10; k++) {
    if (output_data[k] > output_data[best]) {
      best = k;
    }
  }

  return best;
}

/* USER CODE END 2 */



/* Entry points --------------------------------------------------------------*/



void MX_X_CUBE_AI_Init(void)
{
  /* USER CODE BEGIN 5 */
  /* Print a start-up banner, then hand the activation buffer to
   * ai_boostrap(). Presumably ai_boostrap() sets up the `network` handle
   * that MX_X_CUBE_AI_Process() checks (TODO confirm in the generated
   * template). Its return value is deliberately not inspected here. */
  printf("\r\nTEMPLATE - initialization\r\n");
  (void)ai_boostrap(data_activations0);
  /* USER CODE END 5 */
}



void MX_X_CUBE_AI_Process(void)
{
  /* USER CODE BEGIN 6 */
  /* Inference loop: fill the input, run the network, report the predicted
   * digit. The loop repeats while every step reports status 0; a non-zero
   * status ends it and is logged as a failure. */
  int res = -1;

  printf("TEMPLATE - run - main loop\r\n");

  if (network) {
    do {
      /* 1 - acquire and pre-process input data */
      res = acquire_and_process_data((float **)data_ins);

      /* 2 - process the data - call inference engine */
      if (res == 0)
        res = ai_run();

      /* 3 - post-process the predictions */
      if (res == 0) {
        /* data_outs is assumed to be an array of output-buffer pointers,
         * like data_ins (whose buffer is reached as input_data[0] above);
         * TODO confirm against the generated template. The scores live in
         * data_outs[0]. Casting data_outs itself would reinterpret the
         * pointer values as floats and give a meaningless arg-max. */
        int digit = post_process((float *)data_outs[0]);
        printf("predicted digit: %d\r\n", digit);

        /* post_process() returns a class index, not a status code. A
         * prediction of 1..9 must not end the loop as an error. */
        res = 0;
      }
    } while (res == 0);
  }

  if (res) {
    ai_error err = { AI_ERROR_INVALID_STATE, AI_ERROR_CODE_NETWORK };
    ai_log_err(err, "Process has FAILED");
  }
  /* USER CODE END 6 */
}

Question

1. Using TFLite in Python, the model predicts MNIST digits correctly; however, on the STM32H7 it predicts the wrong value. Where is the problem in what I did above?

 

1 REPLY 1
Julian E.
ST Employee

Hello @yiulsup ,

 

Your tflite model and the C model generated by the stedgeai core are different.

You can check whether the C model is close to the original model using the validate command, making sure that the COS metric is very close to 1 (a COS of 1 means that the two models react the same way).

 

Here is the documentation on how to validate a model:

https://stedgeai-dc.st.com/assets/embedded-docs/evaluation_metrics.html 

 

If your COS is bad, it may mean that the conversion went wrong. This can be due to the way the model is quantized, or to the model's architecture.

 

If you have a good COS, then there is most likely an issue in your main.c code.

 

Have a good day,

Julian


In order to give better visibility on the answered topics, please click on 'Accept as Solution' on the reply which solved your issue or answered your question.