cancel
Showing results for 
Search instead for 
Did you mean: 

cube-ai

masirwlw
Visitor

masirwlw_0-1740736241654.png

 

 

#undef __FPU_PRESENT
#include "stm32h7xx.h"
#include "main.h"
#include "crc.h"
#include "memorymap.h"
#include "gpio.h"
#include "app_x-cube-ai.h"
#include "network.h"
#include "ai_platform.h"
#include "network_data.h"
#include <stdlib.h>
#include <stdio.h>
#include "network_data.h"   // 确保包含权重数据声明



/*
 * Sizes (in bytes) of the model's weights and activations buffers.
 *
 * Prefer the values provided by the generated network headers
 * (AI_NETWORK_DATA_WEIGHTS_SIZE / AI_NETWORK_DATA_ACTIVATIONS_SIZE) when they
 * are defined, so these cannot silently drift from the model after it is
 * regenerated. The literals below are only a last-resort fallback and must
 * match the values in the generated network.c / network_data.c.
 */
#ifndef AI_NETWORK_WEIGHTS_SIZE
#  ifdef AI_NETWORK_DATA_WEIGHTS_SIZE
#    define AI_NETWORK_WEIGHTS_SIZE     AI_NETWORK_DATA_WEIGHTS_SIZE
#  else
#    define AI_NETWORK_WEIGHTS_SIZE     794136  /* copied from network.c */
#  endif
#endif

#ifndef AI_NETWORK_ACTIVATIONS_SIZE
#  ifdef AI_NETWORK_DATA_ACTIVATIONS_SIZE
#    define AI_NETWORK_ACTIVATIONS_SIZE AI_NETWORK_DATA_ACTIVATIONS_SIZE
#  else
#    define AI_NETWORK_ACTIVATIONS_SIZE 45572   /* copied from network.c */
#  endif
#endif

/* Model input/output buffers (float element counts taken from the tensor
 * shapes noted below; presumably these must agree with network.c). */
AI_ALIGNED(32) static float input_data[90 * 3 * 1 * 1] = {0}; // input shape (90,3,1,1) - must match network.c exactly
AI_ALIGNED(32) static float output_data[1 * 1 * 6 * 1] = {0}; // output shape (1,1,6,1) - must match network.c



/* Handle of the network instance; AI_HANDLE_NULL until ai_mnetwork_create() fills it in. */
static ai_handle network_handle = AI_HANDLE_NULL;


/* Function prototypes ------------------------------------------------------*/
void SystemClock_Config(void);
static void MPU_Config(void);
static int argmax(const float *arr, int size);
void Error_Handler(void);

int main(void)
{

  MPU_Config();
  HAL_Init();
  SystemClock_Config();
  MX_GPIO_Init();
  MX_CRC_Init();
  MX_Core_Init();
  printf("AI Model Initialized\n");
  
  /* 1. 创建网络实例 */
    ai_error err = ai_mnetwork_create(
        AI_NETWORK_MODEL_NAME,  // 使用network.h中定义的模型名称
        &network_handle,
        NULL  // 使用默认配置
    );
    
    if (err.type != AI_ERROR_NONE) {
        printf("Network create error: 0x%08X\n", err.code);
        Error_Handler();
    }

  printf("1\n");
    
    
     /* 2. 初始化网络参数 */
    ai_network_params params = {
        .params = {
            .format = AI_BUFFER_FORMAT_U8,
            .n_batches = 1,
            .height = 1,
            .width = AI_NETWORK_WEIGHTS_SIZE,
            .channels = 1,
            .data = AI_HANDLE_PTR(ai_network_data_weights_get())
        },
        .activations = {
            .format = AI_BUFFER_FORMAT_U8,
            .n_batches = 1,
            .height = 1,
            .width = AI_NETWORK_ACTIVATIONS_SIZE,
            .channels = 1,
            .data = AI_HANDLE_PTR(input_data)
        }
    };

    if (!ai_mnetwork_init(network_handle, &params)) {
        printf("Network init failed\n");
        ai_mnetwork_destroy(network_handle);
        Error_Handler();
    }
    
      printf("2\n");
    
   /* 3. 配置输入输出缓冲区 */
    ai_buffer input_buf = AI_BUFFER_OBJ_INIT(
        AI_BUFFER_FORMAT_FLOAT,  // format_
        90,                      // height (根据network.c中的AI_SHAPE_INIT(90,3,1,1))
        3,                       // width
        1,                       // channels
        1,                       // n_batches
        input_data               // data
    );

    ai_buffer output_buf = AI_BUFFER_OBJ_INIT(
        AI_BUFFER_FORMAT_FLOAT,
        1,    // height (根据network.c中的AI_SHAPE_INIT(1,1,6,1))
        1,
        6,
        1,
        output_data
    );
    
     /* 主循环 */
    srand(HAL_GetTick());
    
      printf("3\n");
  
  while (1)
  {
 /* 生成输入数据 */
        for (int i = 0; i < 90 * 3; i++) {
            input_data[i] = (2.0f * rand() / RAND_MAX) - 1.0f;
        }
 printf("4\n");
        /* 缓存维护 */
        SCB_CleanDCache_by_Addr(
            (uint32_t*)((uintptr_t)input_data & ~0x1F),
            sizeof(input_data) + 32
        );

            // 在调用前添加地址验证
printf("ai_network_run实际地址: 0x%08X\n", (uint32_t)ai_network_run);
            
            

            
            
        /* 执行推理 */
        int result = ai_mnetwork_run(network_handle, &input_buf, &output_buf);
             
        if (result != AI_ERROR_NONE) {
            printf("Inference error: %d\n", result);
            Error_Handler();
        }

        /* 处理输出 */
        SCB_InvalidateDCache_by_Addr(
            (uint32_t*)((uintptr_t)output_data & ~0x1F),
            sizeof(output_data) + 32
        );

        int class_id = argmax(output_data, 6);
        printf("Predicted class: %d\n", class_id);
        HAL_Delay(1000);
    }
    

}



/**
 * Return the index of the largest element of arr[0 .. size-1].
 * Ties resolve to the lowest index; when size <= 1 the result is 0 and
 * nothing is read from arr.
 */
static int argmax(const float *arr, int size)
{
    int best = 0;

    for (int k = 1; k < size; ++k) {
        best = (arr[k] > arr[best]) ? k : best;
    }
    return best;
}


/**
  * @brief System Clock Configuration
  * @retval None
  */
void SystemClock_Config(void)
{
  RCC_OscInitTypeDef RCC_OscInitStruct = {0};
  RCC_ClkInitTypeDef RCC_ClkInitStruct = {0};

  /** Supply configuration update enable
  */
  HAL_PWREx_ConfigSupply(PWR_LDO_SUPPLY);

  /** Configure the main internal regulator output voltage
  */
  __HAL_PWR_VOLTAGESCALING_CONFIG(PWR_REGULATOR_VOLTAGE_SCALE3);

  while(!__HAL_PWR_GET_FLAG(PWR_FLAG_VOSRDY)) {}

  /** Initializes the RCC Oscillators according to the specified parameters
  * in the RCC_OscInitTypeDef structure.
  */
  RCC_OscInitStruct.OscillatorType = RCC_OSCILLATORTYPE_HSI;
  RCC_OscInitStruct.HSIState = RCC_HSI_DIV1;
  RCC_OscInitStruct.HSICalibrationValue = RCC_HSICALIBRATION_DEFAULT;
  RCC_OscInitStruct.PLL.PLLState = RCC_PLL_NONE;
  if (HAL_RCC_OscConfig(&RCC_OscInitStruct) != HAL_OK)
  {
    Error_Handler();
  }

  /** Initializes the CPU, AHB and APB buses clocks
  */
  RCC_ClkInitStruct.ClockType = RCC_CLOCKTYPE_HCLK|RCC_CLOCKTYPE_SYSCLK
                              |RCC_CLOCKTYPE_PCLK1|RCC_CLOCKTYPE_PCLK2
                              |RCC_CLOCKTYPE_D3PCLK1|RCC_CLOCKTYPE_D1PCLK1;
  RCC_ClkInitStruct.SYSCLKSource = RCC_SYSCLKSOURCE_HSI;
  RCC_ClkInitStruct.SYSCLKDivider = RCC_SYSCLK_DIV1;
  RCC_ClkInitStruct.AHBCLKDivider = RCC_HCLK_DIV1;
  RCC_ClkInitStruct.APB3CLKDivider = RCC_APB3_DIV1;
  RCC_ClkInitStruct.APB1CLKDivider = RCC_APB1_DIV1;
  RCC_ClkInitStruct.APB2CLKDivider = RCC_APB2_DIV2;
  RCC_ClkInitStruct.APB4CLKDivider = RCC_APB4_DIV1;

  if (HAL_RCC_ClockConfig(&RCC_ClkInitStruct, FLASH_LATENCY_1) != HAL_OK)
  {
    Error_Handler();
  }
}

/* USER CODE BEGIN 4 */

/* USER CODE END 4 */

 /* MPU Configuration */



void MPU_Config(void)
{
  MPU_Region_InitTypeDef MPU_InitStruct = {0};

  /* Disables the MPU */
  HAL_MPU_Disable();

  /** Initializes and configures the Region and the memory to be protected
  */
  MPU_InitStruct.Enable = MPU_REGION_ENABLE;
  MPU_InitStruct.Number = MPU_REGION_NUMBER0;
  MPU_InitStruct.BaseAddress = 0x24000000;
  MPU_InitStruct.Size = MPU_REGION_SIZE_512KB;
  MPU_InitStruct.SubRegionDisable = 0x87;
  MPU_InitStruct.TypeExtField = MPU_TEX_LEVEL0;
  MPU_InitStruct.AccessPermission =  MPU_REGION_FULL_ACCESS;
  MPU_InitStruct.DisableExec = MPU_INSTRUCTION_ACCESS_DISABLE;
  MPU_InitStruct.IsShareable = MPU_ACCESS_NOT_SHAREABLE;
  MPU_InitStruct.IsCacheable = MPU_ACCESS_CACHEABLE;
  MPU_InitStruct.IsBufferable = MPU_ACCESS_NOT_BUFFERABLE;

  HAL_MPU_ConfigRegion(&MPU_InitStruct);
  /* Enables the MPU */
  HAL_MPU_Enable(MPU_PRIVILEGED_DEFAULT);

}

/**
  * @brief  Fatal-error trap: masks interrupts and spins forever.
  * @retval None (never returns)
  */
void Error_Handler(void)
{
  /* USER CODE BEGIN Error_Handler_Debug */
  /* User can add his own implementation to report the HAL error return state */
  __disable_irq();
  for (;;)
  {
    /* halt */
  }
  /* USER CODE END Error_Handler_Debug */
}

#ifdef  USE_FULL_ASSERT
/**
  * @brief  Hook reached when an assert_param() check fails; receives the
  *         source file name and line number of the failing check.
  *         Currently does nothing.
  * @param  file: pointer to the source file name
  * @param  line: assert_param error line source number
  * @retval None
  */
void assert_failed(uint8_t *file, uint32_t line)
{
  /* USER CODE BEGIN 6 */
  /* User can add his own implementation to report the file name and line number,
     ex: printf("Wrong parameters value: file %s on line %d\r\n", file, line) */
  (void)file;  /* unused until a report is implemented */
  (void)line;
  /* USER CODE END 6 */
}
#endif /* USE_FULL_ASSERT */

 

 

This is the main loop of my main.c. I call the functions in the app_x-cube-ai.c file generated by X-CUBE-AI. Model initialization succeeds and the model parameters match, but inference stalls (hangs). The call that hangs is `int result = ai_mnetwork_run(network_handle, &input_buf, &output_buf);`. Stepping into that function, execution gets stuck at:

    if (inn)
        return inn->entry->ai_run(inn->handle, input, output);
    else
        return 0;

Please help.

0 REPLIES 0