TangJicheng committed on
Commit 7671dfd
1 Parent(s): 19befe6

chore: init

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ None-ResNet-None-CTC.pth filter=lfs diff=lfs merge=lfs -text
37
+ frozen_east_text_detection.pb filter=lfs diff=lfs merge=lfs -text
38
+ frozen_east_text_detection.tar.gz filter=lfs diff=lfs merge=lfs -text
39
+ img1.jpg filter=lfs diff=lfs merge=lfs -text
40
+ *.onnx filter=lfs diff=lfs merge=lfs -text
None-ResNet-None-CTC.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5846635e760137ddb9d6d1485641d802ddaaa987232b1de79f2be78031cf98dc
3
+ size 177272573
README.md CHANGED
@@ -1,3 +1,387 @@
1
- ---
2
- license: mit
3
- ---
1
+ <!--
2
+ # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ #
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions
6
+ # are met:
7
+ # * Redistributions of source code must retain the above copyright
8
+ # notice, this list of conditions and the following disclaimer.
9
+ # * Redistributions in binary form must reproduce the above copyright
10
+ # notice, this list of conditions and the following disclaimer in the
11
+ # documentation and/or other materials provided with the distribution.
12
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
13
+ # contributors may be used to endorse or promote products derived
14
+ # from this software without specific prior written permission.
15
+ #
16
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
+ -->
28
+
29
+ # Deploy models using Triton
30
+
31
+ | Navigate to | [Part 2: Improving Resource Utilization](../Part_2-improving_resource_utilization/) | [Documentation: Model Repository](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_repository.md) | [Documentation: Model Configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md) |
32
+ | ------------ | --------------- | --------------- | --------------- |
33
+
34
+ Any deep learning inference serving solution needs to tackle two fundamental challenges:
35
+
36
+ * Managing multiple models.
37
+ * Versioning, loading, and unloading models.
38
+
39
+ ## Before we begin
40
+
41
+ The conceptual guide aims to educate developers about the challenges faced whilst building inference infrastructure for deploying deep learning pipelines. `Part 1 - Part 5` of this guide build towards solving a simple problem: deploying a performant and scalable pipeline for transcribing text from images. This pipeline includes 5 steps:
42
+
43
+ 1. Pre-process the raw image
44
+ 1. Detect which parts of the image contain text (Text Detection Model)
45
+ 1. Crop image to regions with text
46
+ 1. Find text probabilities (Text Recognition Model)
47
+ 1. Convert probabilities to actual text
48
+
49
+ In `Part 1`, we start by deploying both models on Triton with the pre/post processing steps done on the client.
50
+
51
+ ## Deploying multiple models
52
+
53
+ The key challenge in managing multiple models is building an infrastructure that can cater to each model's different requirements. For instance, users may need to deploy a PyTorch model and a TensorFlow model on the same server; the two models may carry different loads, need to run on different hardware devices, and require independently managed serving configurations (model queues, versions, caching, acceleration, and more). The Triton Inference Server caters to all of the above and more.
54
+
55
+ ![multiple models](./img/multiple_models.PNG)
56
+
57
+ The first step in deploying models with the Triton Inference Server is building a repository that houses the models to be served along with their configuration schema. For this demonstration, we will use an [EAST](https://arxiv.org/pdf/1704.03155v2.pdf) model to detect text and a text recognition model to read it. This workflow is largely an adaptation of [OpenCV's Text Detection](https://docs.opencv.org/4.x/db/da4/samples_2dnn_2text_detection_8cpp-example.html) sample.
58
+
59
+ To begin, let's clone the repository and navigate to this folder.
60
+
61
+ ```bash
62
+ # Clone the tutorials repository that contains this guide (URL assumed to be the
+ # official Triton tutorials repository), then navigate to this folder.
+ git clone https://github.com/triton-inference-server/tutorials.git
+ cd tutorials/Conceptual_Guide/Part_1-model_deployment
63
+ ```
64
+
65
+ Next, we'll download the necessary models and make sure they are in a format that Triton can deploy.
66
+
67
+ ### Model 1: Text Detection
68
+
69
+ Download and unzip OpenCV's EAST model.
70
+
71
+ ```bash
72
+ wget https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz
73
+ tar -xvf frozen_east_text_detection.tar.gz
74
+ ```
75
+
76
+ Export to ONNX.
77
+ >Note: The following step requires you to have the TensorFlow library installed. We recommend executing the following step within the NGC TensorFlow container environment, which you can launch with `docker run -it --gpus all -v ${PWD}:/workspace nvcr.io/nvidia/tensorflow:<yy.mm>-tf2-py3`
78
+
79
+ ```bash
80
+ pip install -U tf2onnx
81
+ python -m tf2onnx.convert --input frozen_east_text_detection.pb --inputs "input_images:0" --outputs "feature_fusion/Conv_7/Sigmoid:0","feature_fusion/concat_3:0" --output detection.onnx
82
+ ```
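+
+ If the conversion succeeds, it is worth confirming the input and output tensor names of the exported model, since the same names are used later in the model configuration and the client code. Below is a minimal, optional check, assuming the `onnx` Python package is installed in the same environment:
+
+ ```python
+ import onnx
+
+ # Load the exported graph and list its tensor names. These should match the
+ # names used later in config.pbtxt and client.py ("input_images:0",
+ # "feature_fusion/Conv_7/Sigmoid:0", "feature_fusion/concat_3:0").
+ model = onnx.load("detection.onnx")
+ print("inputs: ", [tensor.name for tensor in model.graph.input])
+ print("outputs:", [tensor.name for tensor in model.graph.output])
+ ```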
83
+
84
+ ### Model 2: Text Recognition
85
+
86
+ Download the Text Recognition model weights.
87
+
88
+ ```bash
89
+ wget https://www.dropbox.com/sh/j3xmli4di1zuv3s/AABzCC1KGbIRe2wRwa3diWKwa/None-ResNet-None-CTC.pth
90
+ ```
91
+
92
+ Export the model as `.onnx` using the model definition file in the `utils` folder. This file is adapted from [Baek et al. 2019](https://github.com/clovaai/deep-text-recognition-benchmark).
93
+
94
+ >Note: The following Python script requires you to have the PyTorch library installed. We recommend executing it within the NGC PyTorch container environment, which you can launch with `docker run -it --gpus all -v ${PWD}:/workspace nvcr.io/nvidia/pytorch:<yy.mm>-py3`
95
+
96
+ ```python
97
+ import torch
98
+ from utils.model import STRModel
99
+
100
+ # Create PyTorch Model Object
101
+ model = STRModel(input_channels=1, output_channels=512, num_classes=37)
102
+
103
+ # Load model weights from external file
104
+ state = torch.load("None-ResNet-None-CTC.pth")
105
+ state = {key.replace("module.", ""): value for key, value in state.items()}
106
+ model.load_state_dict(state)
107
+
108
+ # Create ONNX file by tracing model
109
+ trace_input = torch.randn(1, 1, 32, 100)
110
+ torch.onnx.export(model, trace_input, "str.onnx", verbose=True)
111
+ ```
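+
+ Before handing the exported file to Triton, a quick local smoke test can catch export problems early. This is an optional sketch, assuming `onnxruntime` is installed; the expected output shape `(1, 26, 37)` and the tensor names match the recognition model configuration used later in this guide:
+
+ ```python
+ import numpy as np
+ import onnxruntime as ort
+
+ # Run the traced model once on a dummy 32x100 grayscale crop.
+ session = ort.InferenceSession("str.onnx")
+ input_name = session.get_inputs()[0].name    # "input.1" for this export
+ output_name = session.get_outputs()[0].name  # "308" for this export
+ dummy = np.random.randn(1, 1, 32, 100).astype(np.float32)
+ (scores,) = session.run([output_name], {input_name: dummy})
+ print(input_name, output_name, scores.shape)  # expect (1, 26, 37)
+ ```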
112
+
113
+ ### Setting up the model repository
114
+
115
+ A [model repository](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html) is how Triton reads your models and the metadata associated with each of them (configurations, version files, etc.). Model repositories can live on a local or network-attached filesystem, or in a cloud object store like AWS S3, Azure Blob Storage, or Google Cloud Storage. For more details on model repository locations, refer to [the documentation](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html#model-repository-locations). Servers can also use multiple model repositories. For simplicity, this explanation uses a single repository stored on the [local filesystem](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html#local-file-system), in the following format:
116
+
117
+ ```bash
118
+ # Example repository structure
119
+ <model-repository>/
120
+ <model-name>/
121
+ [config.pbtxt]
122
+ [<output-labels-file> ...]
123
+ <version>/
124
+ <model-definition-file>
125
+ <version>/
126
+ <model-definition-file>
127
+ ...
128
+ <model-name>/
129
+ [config.pbtxt]
130
+ [<output-labels-file> ...]
131
+ <version>/
132
+ <model-definition-file>
133
+ <version>/
134
+ <model-definition-file>
135
+ ...
136
+ ...
137
+ ```
138
+
139
+ The structure above has three important components worth discussing:
140
+
141
+ * `model-name`: The identifying name for the model.
142
+ * `config.pbtxt`: For each model, users can define a model configuration. This configuration, at minimum, needs to define: the backend, name, shape, and datatype of model inputs and outputs. For most of the popular backends, this configuration file is autogenerated with defaults. The full specification of the configuration file can be found in the [`model_config` protobuf definition](https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto).
143
+ * `version`: versioning makes multiple versions of the same model available for use depending on the policy selected. [More Information about versioning.](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_repository.html#model-versions)
144
+
145
+ For this example you can set up the model repository structure in the following manner:
146
+
147
+ ```bash
148
+ mkdir -p model_repository/text_detection/1
149
+ mv detection.onnx model_repository/text_detection/1/model.onnx
150
+
151
+ mkdir -p model_repository/text_recognition/1
152
+ mv str.onnx model_repository/text_recognition/1/model.onnx
153
+ ```
154
+
155
+ These commands should give you a repository that looks like this:
156
+
157
+ ```bash
158
+ # Expected folder layout
159
+ model_repository/
160
+ ├── text_detection
161
+ │ ├── 1
162
+ │ │ └── model.onnx
163
+ │ └── config.pbtxt
164
+ └── text_recognition
165
+ ├── 1
166
+ │ └── model.onnx
167
+ └── config.pbtxt
168
+ ```
169
+
170
+ Note that, for this example, we've already created the `config.pbtxt` files and placed them in the necessary location. In the next section, we'll discuss the contents of these files.
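+
+ If you want to double-check the layout before moving on, a few lines of Python run from this folder will print the tree. This is just a convenience sketch and assumes the repository lives at `./model_repository`:
+
+ ```python
+ from pathlib import Path
+
+ # Print every file and folder under the model repository so the layout can be
+ # compared against the expected structure shown above.
+ for path in sorted(Path("model_repository").rglob("*")):
+     print(path)
+ ```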
171
+
172
+ ### Model configuration
173
+
174
+ With the models and the file structure ready, the next thing we need to look at is the `config.pbtxt` model configuration files. Let's first look at the configuration for the EAST text detection model, which has been provided for you at `/model_repository/text_detection/config.pbtxt`. It shows that `text_detection` is an ONNX model with one `input` and two `output` tensors.
175
+
176
+ ``` text proto
177
+ name: "text_detection"
178
+ backend: "onnxruntime"
179
+ max_batch_size : 256
180
+ input [
181
+ {
182
+ name: "input_images:0"
183
+ data_type: TYPE_FP32
184
+ dims: [ -1, -1, -1, 3 ]
185
+ }
186
+ ]
187
+ output [
188
+ {
189
+ name: "feature_fusion/Conv_7/Sigmoid:0"
190
+ data_type: TYPE_FP32
191
+ dims: [ -1, -1, -1, 1 ]
192
+ }
193
+ ]
194
+ output [
195
+ {
196
+ name: "feature_fusion/concat_3:0"
197
+ data_type: TYPE_FP32
198
+ dims: [ -1, -1, -1, 5 ]
199
+ }
200
+ ]
201
+ ```
202
+
203
+ * `name`: An optional field; if provided, its value must match the name of the model's directory.
204
+ * `backend`: This field indicates which backend is being used to run the model. Triton supports a wide variety of backends like TensorFlow, PyTorch, Python, ONNX, and more. For a complete list of supported backends, refer to [these comments](https://github.com/triton-inference-server/backend#backends).
205
+ * `max_batch_size`: As the name implies, this field defines the maximum batch size that the model can support.
206
+ * `input` and `output`: The input and output sections specify the name, shape, datatype, and more, while providing operations like [reshaping](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#reshape) and support for [ragged batches](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/ragged_batching.md#ragged-batching).
207
+
208
+ In [most cases](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#auto-generated-model-configuration), it's possible to leave out the `input` and `output` sections and let Triton extract that information from the model files directly. Here, we've included them for clarity and because we'll need to know the names of our output tensors in the client application later on.
209
+
210
+ For details of all supported fields and their values, refer to the [model config protobuf definition file](https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto).
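+
+ Once the server from the next section is running, you can also ask Triton for the configuration it actually loaded for a model, including any auto-completed fields. A small sketch using the `tritonclient` HTTP API (the same library the client below uses), assuming the server is reachable at `localhost:8000`:
+
+ ```python
+ import json
+ import tritonclient.http as httpclient
+
+ # Ask a running Triton server for the configuration it derived for a model.
+ # Useful for inspecting auto-generated input/output definitions.
+ client = httpclient.InferenceServerClient(url="localhost:8000")
+ config = client.get_model_config("text_detection")
+ print(json.dumps(config, indent=2))
+ ```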
211
+
212
+ ### Launching the server
213
+
214
+ With our repository created and our models configured, we're ready to launch the server. While the Triton Inference Server can be [built from source](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/build.md#building-triton), the use of [pre-built Docker containers](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) freely available from NGC is highly recommended for this example.
215
+
216
+ ```bash
217
+ # Replace the yy.mm in the image name with the release year and month
218
+ # of the Triton version needed, eg. 22.08
219
+
220
+ docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:<yy.mm>-py3
221
+ ```
222
+
223
+ Once the Triton Inference Server has been built, or once you are inside the container, it can be launched with the command:
224
+
225
+ ```bash
226
+ tritonserver --model-repository=/models
227
+ ```
228
+
229
+ This will spin up the server and model instances will be ready for inference.
230
+
231
+ ```text
232
+ I0712 16:37:18.246487 128 server.cc:626]
233
+ +------------------+---------+--------+
234
+ | Model | Version | Status |
235
+ +------------------+---------+--------+
236
+ | text_detection | 1 | READY |
237
+ | text_recognition | 1 | READY |
238
+ +------------------+---------+--------+
239
+
240
+ I0712 16:37:18.267625 128 metrics.cc:650] Collecting metrics for GPU 0: NVIDIA GeForce RTX 3090
241
+ I0712 16:37:18.268041 128 tritonserver.cc:2159]
242
+ +----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
243
+ | Option | Value |
244
+ +----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
245
+ | server_id | triton |
246
+ | server_version | 2.23.0 |
247
+ | server_extensions | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data statistics trace |
248
+ | model_repository_path[0] | /models |
249
+ | model_control_mode | MODE_NONE |
250
+ | strict_model_config | 1 |
251
+ | rate_limit | OFF |
252
+ | pinned_memory_pool_byte_size | 268435456 |
253
+ | cuda_memory_pool_byte_size{0} | 67108864 |
254
+ | response_cache_byte_size | 0 |
255
+ | min_supported_compute_capability | 6.0 |
256
+ | strict_readiness | 1 |
257
+ | exit_timeout | 30 |
258
+ +----------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
259
+
260
+ I0712 16:37:18.269464 128 grpc_server.cc:4587] Started GRPCInferenceService at 0.0.0.0:8001
261
+ I0712 16:37:18.269956 128 http_server.cc:3303] Started HTTPService at 0.0.0.0:8000
262
+ I0712 16:37:18.311686 128 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002
263
+ ```
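+
+ With the server up, you can verify liveness and per-model readiness programmatically before sending any requests. A minimal sketch using the `tritonclient[http]` package (installed in the next section), assuming the default HTTP port:
+
+ ```python
+ import tritonclient.http as httpclient
+
+ # Confirm the server is ready and that both models loaded successfully.
+ client = httpclient.InferenceServerClient(url="localhost:8000")
+ print("server ready:", client.is_server_ready())
+ for model_name in ("text_detection", "text_recognition"):
+     print(model_name, "ready:", client.is_model_ready(model_name))
+ ```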
264
+
265
+ ## Building a client application
266
+
267
+ Now that our Triton server has been launched, we can start sending requests to it. There are three ways to interact with the Triton Inference Server:
268
+
269
+ * HTTP(S) API
270
+ * gRPC API
271
+ * Native C API
272
+
273
+ There are also pre-built [client libraries](https://github.com/triton-inference-server/client#client-library-apis) in [C++](https://github.com/triton-inference-server/client/tree/main/src/c%2B%2B), [Python](https://github.com/triton-inference-server/client/tree/main/src/python), and [Java](https://github.com/triton-inference-server/client/tree/main/src/java) that wrap the HTTP and gRPC APIs. This example contains a Python client script, `client.py`, which uses the `tritonclient` Python library to communicate with Triton over the HTTP API.
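+
+ The gRPC client mirrors the HTTP client almost exactly, so switching between the two is mostly a matter of importing a different submodule and pointing at the gRPC port. A small sketch of the equivalent connection setup, assuming the `grpc` extra of `tritonclient` is installed:
+
+ ```python
+ import tritonclient.grpc as grpcclient
+ import tritonclient.http as httpclient
+
+ # The HTTP and gRPC clients expose a near-identical API; only the module and
+ # the default port differ (8000 for HTTP, 8001 for gRPC).
+ http_client = httpclient.InferenceServerClient(url="localhost:8000")
+ grpc_client = grpcclient.InferenceServerClient(url="localhost:8001")
+ ```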
274
+
275
+ Let's examine the contents of this file:
276
+
277
+ * First, we import our HTTP client from the `tritonclient` library, as well as a few other libraries we'll use for processing our images:
278
+
279
+ ```python
280
+ import math
281
+ import numpy as np
282
+ import cv2
283
+ import tritonclient.http as httpclient
284
+ ```
285
+
286
+ * Next, we'll define a few helper functions that take care of the pre- and post-processing steps for our pipeline. Their bodies are omitted here for brevity; see the `client.py` file for the full implementations.
287
+
288
+ ```python
289
+ def detection_preprocessing(image: cv2.Mat) -> np.ndarray:
290
+ ...
291
+
292
+ def detection_postprocessing(scores: np.ndarray, geometry: np.ndarray, preprocessed_image: np.ndarray) -> np.ndarray:
293
+ ...
294
+
295
+ def recognition_postprocessing(scores: np.ndarray) -> str:
296
+ ...
297
+ ```
298
+
299
+ * Then, we create a client object, and initialize a connection with the Triton Inference Server.
300
+
301
+ ```python
302
+ client = httpclient.InferenceServerClient(url="localhost:8000")
303
+ ```
304
+
305
+ * Now, we'll create the `InferInput` that we'll be sending to Triton from our data.
306
+
307
+ ```python
308
+ raw_image = cv2.imread("./img1.jpg")
309
+ preprocessed_image = detection_preprocessing(raw_image)
310
+
311
+ detection_input = httpclient.InferInput("input_images:0", preprocessed_image.shape, datatype="FP32")
312
+ detection_input.set_data_from_numpy(preprocessed_image, binary_data=True)
313
+ ```
314
+
315
+ * Finally, we're ready to send an inference request to the Triton Inference Server and retrieve the response
316
+
317
+ ```python
318
+ detection_response = client.infer(model_name="text_detection", inputs=[detection_input])
319
+ ```
320
+
321
+ * After that, we'll repeat the process with the text recognition model: perform the next processing step, create the input object, query the server, and finally run the postprocessing and print the result.
322
+
323
+ ```python
324
+ # Process responses from detection model
325
+ scores = detection_response.as_numpy('feature_fusion/Conv_7/Sigmoid:0')
326
+ geometry = detection_response.as_numpy('feature_fusion/concat_3:0')
327
+ cropped_images = detection_postprocessing(scores, geometry, preprocessed_image)
328
+
329
+ # Create input object for recognition model
330
+ recognition_input = httpclient.InferInput("input.1", cropped_images.shape, datatype="FP32")
331
+ recognition_input.set_data_from_numpy(cropped_images, binary_data=True)
332
+
333
+ # Query the server
334
+ recognition_response = client.infer(model_name="text_recognition", inputs=[recognition_input])
335
+
336
+ # Process response from recognition model
337
+ text = recognition_postprocessing(recognition_response.as_numpy('308'))
338
+
339
+ print(text)
340
+ ```
341
+
342
+ Let's try it out!
343
+
344
+ ```bash
345
+ pip install tritonclient[http] opencv-python-headless
346
+ python client.py
347
+ ```
348
+
349
+ You might have noticed that it's a bit redundant to retrieve the results of the first model only to do some processing and send them right back to Triton. In [Part 5](../Part_5-Model_Ensembles/) of this tutorial we explore how you can move more processing steps to the server and execute multiple models in a single network call.
350
+
351
+ ## Model Versioning
352
+
353
+ The ability to deploy different versions of a model is essential to building an MLOps pipeline. The need arises from use cases like A/B testing, easy rollbacks to earlier model versions, and more. To add a new version, Triton users simply add another numbered folder containing the new model to the same model directory:
354
+
355
+ ```text
356
+ model_repository/
357
+ ├── text_detection
358
+ │ ├── 1
359
+ │ │ └── model.onnx
360
+ │ ├── 2
361
+ │ │ └── model.onnx
362
+ │ └── config.pbtxt
363
+ └── text_recognition
364
+ ├── 1
365
+ │ └── model.onnx
366
+ └── config.pbtxt
367
+ ```
368
+
369
+ By default, Triton serves the latest version of the model, but the policy for serving different versions is customizable. For more information, [refer to this guide](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#version-policy).
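+
+ You can also inspect which models and versions a running server currently exposes. A small sketch with the HTTP client; the exact fields returned may vary slightly between Triton releases:
+
+ ```python
+ import tritonclient.http as httpclient
+
+ # List every model (and version) in the repository along with its load state.
+ client = httpclient.InferenceServerClient(url="localhost:8000")
+ for entry in client.get_model_repository_index():
+     print(entry.get("name"), entry.get("version"), entry.get("state"))
+ ```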
370
+
371
+ ## Loading & Unloading Models
372
+
373
+ Triton has a model management API that can be used to control model loading and unloading policies. This API is extremely useful when one or more models need to be loaded or unloaded without interrupting inference for the other models being served on the same server. Users can select one of three control modes:
374
+
375
+ * NONE
376
+ * EXPLICIT
377
+ * POLL
378
+
379
+ ```bash
380
+ tritonserver --model-repository=/models --model-control-mode=poll
381
+ ```
382
+
383
+ As shown above, the control mode is set via a command line argument when launching the server. For more information, refer to [this section](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md#model-management) of the documentation.
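+
+ In `EXPLICIT` mode, loading and unloading can also be driven from the client side through the model control API. A minimal sketch, assuming the server was started with `--model-control-mode=explicit`:
+
+ ```python
+ import tritonclient.http as httpclient
+
+ # With --model-control-mode=explicit, models are loaded and unloaded on demand.
+ client = httpclient.InferenceServerClient(url="localhost:8000")
+ client.load_model("text_detection")
+ print("loaded:", client.is_model_ready("text_detection"))
+ client.unload_model("text_detection")
+ print("after unload:", client.is_model_ready("text_detection"))
+ ```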
384
+
385
+ # What's next?
386
+
387
+ In this tutorial, we covered the very basics of setting up and querying a Triton Inference Server. This is Part 1 of a 6-part tutorial series that covers the challenges faced in deploying deep learning models to production. [Part 2](../Part_2-improving_resource_utilization/) covers `Concurrent Model Execution and Dynamic Batching`. Depending on your workload and experience, you might want to jump to [Part 5](../Part_5-Model_Ensembles/), which covers `Building an Ensemble Pipeline with multiple models, pre and post processing steps, and adding business logic`.
client.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ import math
28
+ import numpy as np
29
+ import cv2
30
+ import tritonclient.http as httpclient
31
+
32
+ SAVE_INTERMEDIATE_IMAGES = False
33
+
34
+
35
+ def detection_preprocessing(image: cv2.Mat) -> np.ndarray:
36
+
37
+ inpWidth = 640
38
+ inpHeight = 480
39
+
40
+ # pre-process image
41
+ blob = cv2.dnn.blobFromImage(
42
+ image, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False
43
+ )
44
+ blob = np.transpose(blob, (0, 2, 3, 1))
45
+ return blob
46
+
47
+
48
+ def detection_postprocessing(scores, geometry, preprocessed_image):
49
+ def fourPointsTransform(frame, vertices):
50
+ vertices = np.asarray(vertices)
51
+ outputSize = (100, 32)
52
+ targetVertices = np.array(
53
+ [
54
+ [0, outputSize[1] - 1],
55
+ [0, 0],
56
+ [outputSize[0] - 1, 0],
57
+ [outputSize[0] - 1, outputSize[1] - 1],
58
+ ],
59
+ dtype="float32",
60
+ )
61
+
62
+ rotationMatrix = cv2.getPerspectiveTransform(vertices, targetVertices)
63
+ result = cv2.warpPerspective(frame, rotationMatrix, outputSize)
64
+ return result
65
+
66
+ def decodeBoundingBoxes(scores, geometry, scoreThresh=0.5):
67
+ detections = []
68
+ confidences = []
69
+
70
+ ############ CHECK DIMENSIONS AND SHAPES OF geometry AND scores ########
71
+ assert len(scores.shape) == 4, "Incorrect dimensions of scores"
72
+ assert len(geometry.shape) == 4, "Incorrect dimensions of geometry"
73
+ assert scores.shape[0] == 1, "Invalid dimensions of scores"
74
+ assert geometry.shape[0] == 1, "Invalid dimensions of geometry"
75
+ assert scores.shape[1] == 1, "Invalid dimensions of scores"
76
+ assert geometry.shape[1] == 5, "Invalid dimensions of geometry"
77
+ assert (
78
+ scores.shape[2] == geometry.shape[2]
79
+ ), "Invalid dimensions of scores and geometry"
80
+ assert (
81
+ scores.shape[3] == geometry.shape[3]
82
+ ), "Invalid dimensions of scores and geometry"
83
+ height = scores.shape[2]
84
+ width = scores.shape[3]
85
+ for y in range(0, height):
86
+ # Extract data from scores
87
+ scoresData = scores[0][0][y]
88
+ x0_data = geometry[0][0][y]
89
+ x1_data = geometry[0][1][y]
90
+ x2_data = geometry[0][2][y]
91
+ x3_data = geometry[0][3][y]
92
+ anglesData = geometry[0][4][y]
93
+ for x in range(0, width):
94
+ score = scoresData[x]
95
+
96
+ # If score is lower than threshold score, move to next x
97
+ if score < scoreThresh:
98
+ continue
99
+
100
+ # Calculate offset
101
+ offsetX = x * 4.0
102
+ offsetY = y * 4.0
103
+ angle = anglesData[x]
104
+
105
+ # Calculate cos and sin of angle
106
+ cosA = math.cos(angle)
107
+ sinA = math.sin(angle)
108
+ h = x0_data[x] + x2_data[x]
109
+ w = x1_data[x] + x3_data[x]
110
+
111
+ # Calculate offset
112
+ offset = [
113
+ offsetX + cosA * x1_data[x] + sinA * x2_data[x],
114
+ offsetY - sinA * x1_data[x] + cosA * x2_data[x],
115
+ ]
116
+
117
+ # Find points for rectangle
118
+ p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
119
+ p3 = (-cosA * w + offset[0], sinA * w + offset[1])
120
+ center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
121
+ detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
122
+ confidences.append(float(score))
123
+
124
+ # Return detections and confidences
125
+ return [detections, confidences]
126
+
127
+ scores = scores.transpose(0, 3, 1, 2)
128
+ geometry = geometry.transpose(0, 3, 1, 2)
129
+ frame = np.squeeze(preprocessed_image, axis=0)
130
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
131
+ [boxes, confidences] = decodeBoundingBoxes(scores, geometry)
132
+ indices = cv2.dnn.NMSBoxesRotated(boxes, confidences, 0.5, 0.4)
133
+
134
+ cropped_list = []
135
+ cv2.imwrite("frame.png", frame)
136
+ count = 0
137
+ for i in indices:
138
+ # get 4 corners of the rotated rect
139
+ count += 1
140
+ vertices = cv2.boxPoints(boxes[i])
141
+ cropped = fourPointsTransform(frame, vertices)
142
+ cv2.imwrite(str(count) + ".png", cropped)
143
+ cropped = np.expand_dims(cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY), axis=0)
144
+
145
+ cropped_list.append(((cropped / 255.0) - 0.5) * 2)
146
+ cropped_arr = np.stack(cropped_list, axis=0)
147
+
148
+ # Only keep the first image, since the models don't currently allow batching.
149
+ # See part 2 for enabling batch sizes > 0
150
+ return cropped_arr[None, 0]
151
+
152
+
153
+ def recognition_postprocessing(scores: np.ndarray) -> str:
154
+ text = ""
155
+ alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
156
+
157
+ scores = np.transpose(scores, (1,0,2))
158
+
159
+ for i in range(scores.shape[0]):
160
+ c = np.argmax(scores[i][0])
161
+ if c != 0:
162
+ text += alphabet[c - 1]
163
+ else:
164
+ text += "-"
165
+ # adjacent same letters as well as background text must be removed
166
+ # to get the final output
167
+ char_list = []
168
+ for i, char in enumerate(text):
169
+ if char != "-" and (not (i > 0 and char == text[i - 1])):
170
+ char_list.append(char)
171
+ return "".join(char_list)
172
+
173
+
174
+ if __name__ == "__main__":
175
+
176
+ # Setting up client
177
+ client = httpclient.InferenceServerClient(url="localhost:8000")
178
+
179
+ # Read image and create input object
180
+ raw_image = cv2.imread("./img1.jpg")
181
+ preprocessed_image = detection_preprocessing(raw_image)
182
+
183
+ detection_input = httpclient.InferInput(
184
+ "input_images:0", preprocessed_image.shape, datatype="FP32"
185
+ )
186
+ detection_input.set_data_from_numpy(preprocessed_image, binary_data=True)
187
+
188
+ # Query the server
189
+ detection_response = client.infer(
190
+ model_name="text_detection", inputs=[detection_input]
191
+ )
192
+
193
+ # Process responses from detection model
194
+ scores = detection_response.as_numpy("feature_fusion/Conv_7/Sigmoid:0")
195
+ geometry = detection_response.as_numpy("feature_fusion/concat_3:0")
196
+ cropped_images = detection_postprocessing(scores, geometry, preprocessed_image)
197
+
198
+ # Create input object for recognition model
199
+ recognition_input = httpclient.InferInput(
200
+ "input.1", cropped_images.shape, datatype="FP32"
201
+ )
202
+ recognition_input.set_data_from_numpy(cropped_images, binary_data=True)
203
+
204
+ # Query the server
205
+ recognition_response = client.infer(
206
+ model_name="text_recognition", inputs=[recognition_input]
207
+ )
208
+
209
+ # Process response from recognition model
210
+ final_text = recognition_postprocessing(recognition_response.as_numpy("308"))
211
+
212
+ print(final_text)
convert_to_str.py ADDED
@@ -0,0 +1,14 @@
1
+ import torch
2
+ from utils.model import STRModel
3
+
4
+ # Create PyTorch Model Object
5
+ model = STRModel(input_channels=1, output_channels=512, num_classes=37)
6
+
7
+ # Load model weights from external file
8
+ state = torch.load("None-ResNet-None-CTC.pth", map_location=torch.device('cpu'))
9
+ state = {key.replace("module.", ""): value for key, value in state.items()}
10
+ model.load_state_dict(state)
11
+
12
+ # Create ONNX file by tracing model
13
+ trace_input = torch.randn(1, 1, 32, 100)
14
+ torch.onnx.export(model, trace_input, "str.onnx", verbose=True)
frozen_east_text_detection.pb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b486f3c3eee77b4c8cc91a83892c37026cca7d29b79bf3b93772ccd2db58454
3
+ size 96662756
frozen_east_text_detection.tar.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba5e7c7e9080c3014ce5f253a397a72c7be0d83b51d0a6f9fc6b0ad80db7fe9d
3
+ size 89824490
img/multiple_models.PNG ADDED
img1.jpg ADDED

Git LFS Details

  • SHA256: 09ad97312c2afb63040ebc28a893062b68b561607e05ce83c3a7f302bdc537fd
  • Pointer size: 131 Bytes
  • Size of remote file: 873 kB
model_repository/text_detection/1/model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92c96959955dcf92ef5a9bba0c5f1979e05f8556ae68f35f0accbe7a68b0e43c
3
+ size 96224417
model_repository/text_detection/config.pbtxt ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "text_detection"
28
+ backend: "onnxruntime"
29
+ max_batch_size : 0
30
+ input [
31
+ {
32
+ name: "input_images:0"
33
+ data_type: TYPE_FP32
34
+ dims: [ -1, -1, -1, 3 ]
35
+ }
36
+ ]
37
+ output [
38
+ {
39
+ name: "feature_fusion/Conv_7/Sigmoid:0"
40
+ data_type: TYPE_FP32
41
+ dims: [ -1, -1, -1, 1 ]
42
+ }
43
+ ]
44
+ output [
45
+ {
46
+ name: "feature_fusion/concat_3:0"
47
+ data_type: TYPE_FP32
48
+ dims: [ -1, -1, -1, 5 ]
49
+ }
50
+ ]
model_repository/text_recognition/1/model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73e18fde8ed993862ecc7636fed1f70ff810bee1345bef25f79e549a17fcb347
3
+ size 177304847
model_repository/text_recognition/config.pbtxt ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Redistribution and use in source and binary forms, with or without
4
+ # modification, are permitted provided that the following conditions
5
+ # are met:
6
+ # * Redistributions of source code must retain the above copyright
7
+ # notice, this list of conditions and the following disclaimer.
8
+ # * Redistributions in binary form must reproduce the above copyright
9
+ # notice, this list of conditions and the following disclaimer in the
10
+ # documentation and/or other materials provided with the distribution.
11
+ # * Neither the name of NVIDIA CORPORATION nor the names of its
12
+ # contributors may be used to endorse or promote products derived
13
+ # from this software without specific prior written permission.
14
+ #
15
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16
+ # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18
+ # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19
+ # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20
+ # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21
+ # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22
+ # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
+
27
+ name: "text_recognition"
28
+ backend: "onnxruntime"
29
+ max_batch_size : 0
30
+ input [
31
+ {
32
+ name: "input.1"
33
+ data_type: TYPE_FP32
34
+ dims: [ 1, 1, 32, 100 ]
35
+ }
36
+ ]
37
+ output [
38
+ {
39
+ name: "308"
40
+ data_type: TYPE_FP32
41
+ dims: [ 1, 26, 37 ]
42
+ }
43
+ ]
utils/model.py ADDED
@@ -0,0 +1,248 @@
1
+ """
2
+ Copyright (c) 2019-present NAVER Corp.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch.nn as nn
18
+
19
+ # from modules.feature_extraction import ResNet_FeatureExtractor
20
+
21
+
22
+ class ResNet_FeatureExtractor(nn.Module):
23
+ """FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf)"""
24
+
25
+ def __init__(self, input_channel, output_channel=512):
26
+ super(ResNet_FeatureExtractor, self).__init__()
27
+ self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
28
+
29
+ def forward(self, input):
30
+ return self.ConvNet(input)
31
+
32
+
33
+ class BasicBlock(nn.Module):
34
+ expansion = 1
35
+
36
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
37
+ super(BasicBlock, self).__init__()
38
+ self.conv1 = self._conv3x3(inplanes, planes)
39
+ self.bn1 = nn.BatchNorm2d(planes)
40
+ self.conv2 = self._conv3x3(planes, planes)
41
+ self.bn2 = nn.BatchNorm2d(planes)
42
+ self.relu = nn.ReLU(inplace=True)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def _conv3x3(self, in_planes, out_planes, stride=1):
47
+ "3x3 convolution with padding"
48
+ return nn.Conv2d(
49
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
50
+ )
51
+
52
+ def forward(self, x):
53
+ residual = x
54
+
55
+ out = self.conv1(x)
56
+ out = self.bn1(out)
57
+ out = self.relu(out)
58
+
59
+ out = self.conv2(out)
60
+ out = self.bn2(out)
61
+
62
+ if self.downsample is not None:
63
+ residual = self.downsample(x)
64
+ out += residual
65
+ out = self.relu(out)
66
+
67
+ return out
68
+
69
+
70
+ class ResNet(nn.Module):
71
+ def __init__(self, input_channel, output_channel, block, layers):
72
+ super(ResNet, self).__init__()
73
+
74
+ self.output_channel_block = [
75
+ int(output_channel / 4),
76
+ int(output_channel / 2),
77
+ output_channel,
78
+ output_channel,
79
+ ]
80
+
81
+ self.inplanes = int(output_channel / 8)
82
+ self.conv0_1 = nn.Conv2d(
83
+ input_channel,
84
+ int(output_channel / 16),
85
+ kernel_size=3,
86
+ stride=1,
87
+ padding=1,
88
+ bias=False,
89
+ )
90
+ self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
91
+ self.conv0_2 = nn.Conv2d(
92
+ int(output_channel / 16),
93
+ self.inplanes,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ )
99
+ self.bn0_2 = nn.BatchNorm2d(self.inplanes)
100
+ self.relu = nn.ReLU(inplace=True)
101
+
102
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
103
+ self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
104
+ self.conv1 = nn.Conv2d(
105
+ self.output_channel_block[0],
106
+ self.output_channel_block[0],
107
+ kernel_size=3,
108
+ stride=1,
109
+ padding=1,
110
+ bias=False,
111
+ )
112
+ self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
113
+
114
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
115
+ self.layer2 = self._make_layer(
116
+ block, self.output_channel_block[1], layers[1], stride=1
117
+ )
118
+ self.conv2 = nn.Conv2d(
119
+ self.output_channel_block[1],
120
+ self.output_channel_block[1],
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1,
124
+ bias=False,
125
+ )
126
+ self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
127
+
128
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
129
+ self.layer3 = self._make_layer(
130
+ block, self.output_channel_block[2], layers[2], stride=1
131
+ )
132
+ self.conv3 = nn.Conv2d(
133
+ self.output_channel_block[2],
134
+ self.output_channel_block[2],
135
+ kernel_size=3,
136
+ stride=1,
137
+ padding=1,
138
+ bias=False,
139
+ )
140
+ self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
141
+
142
+ self.layer4 = self._make_layer(
143
+ block, self.output_channel_block[3], layers[3], stride=1
144
+ )
145
+ self.conv4_1 = nn.Conv2d(
146
+ self.output_channel_block[3],
147
+ self.output_channel_block[3],
148
+ kernel_size=2,
149
+ stride=(2, 1),
150
+ padding=(0, 1),
151
+ bias=False,
152
+ )
153
+ self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
154
+ self.conv4_2 = nn.Conv2d(
155
+ self.output_channel_block[3],
156
+ self.output_channel_block[3],
157
+ kernel_size=2,
158
+ stride=1,
159
+ padding=0,
160
+ bias=False,
161
+ )
162
+ self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
163
+
164
+ def _make_layer(self, block, planes, blocks, stride=1):
165
+ downsample = None
166
+ if stride != 1 or self.inplanes != planes * block.expansion:
167
+ downsample = nn.Sequential(
168
+ nn.Conv2d(
169
+ self.inplanes,
170
+ planes * block.expansion,
171
+ kernel_size=1,
172
+ stride=stride,
173
+ bias=False,
174
+ ),
175
+ nn.BatchNorm2d(planes * block.expansion),
176
+ )
177
+
178
+ layers = []
179
+ layers.append(block(self.inplanes, planes, stride, downsample))
180
+ self.inplanes = planes * block.expansion
181
+ for i in range(1, blocks):
182
+ layers.append(block(self.inplanes, planes))
183
+
184
+ return nn.Sequential(*layers)
185
+
186
+ def forward(self, x):
187
+ x = self.conv0_1(x)
188
+ x = self.bn0_1(x)
189
+ x = self.relu(x)
190
+ x = self.conv0_2(x)
191
+ x = self.bn0_2(x)
192
+ x = self.relu(x)
193
+
194
+ x = self.maxpool1(x)
195
+ x = self.layer1(x)
196
+ x = self.conv1(x)
197
+ x = self.bn1(x)
198
+ x = self.relu(x)
199
+
200
+ x = self.maxpool2(x)
201
+ x = self.layer2(x)
202
+ x = self.conv2(x)
203
+ x = self.bn2(x)
204
+ x = self.relu(x)
205
+
206
+ x = self.maxpool3(x)
207
+ x = self.layer3(x)
208
+ x = self.conv3(x)
209
+ x = self.bn3(x)
210
+ x = self.relu(x)
211
+
212
+ x = self.layer4(x)
213
+ x = self.conv4_1(x)
214
+ x = self.bn4_1(x)
215
+ x = self.relu(x)
216
+ x = self.conv4_2(x)
217
+ x = self.bn4_2(x)
218
+ x = self.relu(x)
219
+
220
+ return x
221
+
222
+
223
+ class STRModel(nn.Module):
224
+ def __init__(self, input_channels, output_channels, num_classes):
225
+ super(STRModel, self).__init__()
226
+ self.FeatureExtraction = ResNet_FeatureExtractor(
227
+ input_channels, output_channels
228
+ )
229
+ self.FeatureExtraction_output = output_channels # int(imgH/16-1) * 512
230
+ self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
231
+ (self.FeatureExtraction_output, 1)
232
+ ) # Transform final (imgH/16-1) -> 1
233
+ self.SequenceModeling_output = self.FeatureExtraction_output
234
+ self.Prediction = nn.Linear(self.SequenceModeling_output, num_classes)
235
+
236
+ def forward(self, input):
237
+
238
+ """Feature extraction stage"""
239
+ visual_feature = self.FeatureExtraction(input)
240
+ visual_feature = self.AdaptiveAvgPool(
241
+ visual_feature.permute(0, 3, 1, 2)
242
+ ) # [b, c, h, w] -> [b, w, c, h]
243
+ visual_feature = visual_feature.squeeze(3)
244
+
245
+ """ Prediction stage """
246
+ prediction = self.Prediction(visual_feature.contiguous())
247
+
248
+ return prediction