anycores commited on
Commit
0758c16
1 Parent(s): a16a836

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.cpp +13 -4
  2. xg_runtime_api.h +17 -3
main.cpp CHANGED
@@ -22,7 +22,6 @@ void test_whisper(const std::string& weight_path, const std::string& input_path)
22
  xg_get_model_info(&minfo);
23
  std::cout << minfo.model_name << " " << minfo.model_version << std::endl;
24
 
25
- std::cout << "initing graph" << std::endl;
26
  XgGraph* graph = nullptr;
27
  if (xg_init_graph(weight_path, XGWeightSource::XG_ONNX, &graph) != XGResult::XG_SUCCESS)
28
  {
@@ -46,7 +45,10 @@ void test_whisper(const std::string& weight_path, const std::string& input_path)
46
  }
47
 
48
  // load the data into XgData
49
- reinterpret_cast<std::string*>(input_data->raw_data)[0] = input_path;
 
 
 
50
 
51
  if (xg_set_input_data(graph, 0, input_data) != XGResult::XG_SUCCESS)
52
  {
@@ -74,6 +76,13 @@ void test_whisper(const std::string& weight_path, const std::string& input_path)
74
  }
75
 
76
  // print output
77
- std::string* o1 = reinterpret_cast<std::string*>(output_data->raw_data);
78
- std::cout << o1[0] << std::endl;
 
 
 
 
 
 
 
79
  }
 
22
  xg_get_model_info(&minfo);
23
  std::cout << minfo.model_name << " " << minfo.model_version << std::endl;
24
 
 
25
  XgGraph* graph = nullptr;
26
  if (xg_init_graph(weight_path, XGWeightSource::XG_ONNX, &graph) != XGResult::XG_SUCCESS)
27
  {
 
45
  }
46
 
47
  // load the data into XgData
48
+ xg_copy_stdstrings_to_data(
49
+ std::vector<std::string>{input_path},
50
+ input_data
51
+ );
52
 
53
  if (xg_set_input_data(graph, 0, input_data) != XGResult::XG_SUCCESS)
54
  {
 
76
  }
77
 
78
  // print output
79
+ std::vector<std::string> texts;
80
+ size_t num_texts = xg_get_num_of_strings(output_data);
81
+ xg_copy_data_to_stdstrings(num_texts, output_data, texts);
82
+ std::cout << texts[0] << std::endl;
83
+
84
+ // clean up
85
+ xg_destroy_data(&input_data);
86
+ xg_destroy_data(&output_data);
87
+ xg_destroy_graph(&graph);
88
  }
xg_runtime_api.h CHANGED
@@ -115,13 +115,14 @@ XG_API XGResult xg_get_output_data(
115
  XgData** data
116
  );
117
  XG_API XGResult xg_set_input_data(
118
- const XgGraph* graph,
119
- const unsigned int input_idx,
120
- const XgData* data
121
  );
122
 
123
  // helper functions
124
  XG_API bool xg_is_data_bool(const XgData* data);
 
125
  XG_API bool xg_is_data_uint8(const XgData* data);
126
  XG_API bool xg_is_data_uint16(const XgData* data);
127
  XG_API bool xg_is_data_uint32(const XgData* data);
@@ -135,4 +136,17 @@ XG_API bool xg_is_data_float16(const XgData* data);
135
  XG_API bool xg_is_data_float32(const XgData* data);
136
  XG_API bool xg_is_data_float64(const XgData* data);
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  #endif // __XG_RUNTIME_API__
 
115
  XgData** data
116
  );
117
  XG_API XGResult xg_set_input_data(
118
+ const XgGraph* graph,
119
+ const unsigned int input_idx,
120
+ const XgData* data
121
  );
122
 
123
  // helper functions
124
  XG_API bool xg_is_data_bool(const XgData* data);
125
+ XG_API bool xg_is_data_string(const XgData* data);
126
  XG_API bool xg_is_data_uint8(const XgData* data);
127
  XG_API bool xg_is_data_uint16(const XgData* data);
128
  XG_API bool xg_is_data_uint32(const XgData* data);
 
136
  XG_API bool xg_is_data_float32(const XgData* data);
137
  XG_API bool xg_is_data_float64(const XgData* data);
138
 
139
+ XG_API void xg_copy_stdstrings_to_data(
140
+ const std::vector<std::string>& ss,
141
+ XgData* data
142
+ );
143
+ XG_API void xg_copy_data_to_stdstrings(
144
+ const size_t length, // number of strings
145
+ const XgData* data,
146
+ std::vector<std::string>& ss
147
+ );
148
+ XG_API size_t xg_get_num_of_strings(
149
+ const XgData* data
150
+ );
151
+
152
  #endif // __XG_RUNTIME_API__