Sébastien De Greef committed • Commit db1f0f8 • 1 parent: f32aa09

feat: Add slideshow on optimizers in neural networks
Files changed:

- src/theory/ActivationFunctions.ipynb (+0 -0)
- src/theory/activations.qmd (+264 -45)
- src/theory/activations_slideshow.qmd (+355 -0)
- src/theory/architectures.qmd (+734 -75)
- src/theory/layers.qmd (+30 -15)
- src/theory/metrics.qmd (+443 -19)
- src/theory/optimizers.qmd (+89 -0)
- src/theory/optimizers_slideshow.qmd (+100 -0)
- src/theory/training.qmd (+191 -0)
src/theory/ActivationFunctions.ipynb (ADDED)

The diff for this file is too large to render. See the raw diff.
src/theory/activations.qmd (CHANGED)

This update adds YAML front matter, moves the activation-selection guidance from the end of the file to the top, links the new slideshow, and gives each activation a two-column layout pairing its formula and a NumPy snippet with an embedded figure from ActivationFunctions.ipynb. The old trailing guidance list and a plain-text formula line (`f(x) = x * sigmoid(x)`) were replaced by the new layout. The updated file reads:

---
title: "Activation functions"
notebook-links: false
crossref:
  lof-title: "List of Figures"
number-sections: false
---

When choosing an activation function, consider the following:

- **Non-saturation:** Avoid activations that saturate (e.g., sigmoid, tanh) to prevent vanishing gradients.
- **Computational efficiency:** Choose activations that are computationally efficient (e.g., ReLU, Swish) for large models or real-time applications.
- **Smoothness:** Smooth activations (e.g., GELU, Mish) can help with optimization and convergence.
- **Domain knowledge:** Select activations based on the problem domain and desired output (e.g., softmax for multi-class classification).
- **Experimentation:** Try different activations and evaluate their performance on your specific task.

[Slideshow](activations_slideshow.qmd)

{{< embed ActivationFunctions.ipynb#fig-overview >}}

## Sigmoid {#sec-sigmoid}

**Strengths:** Maps any real-valued number to a value between 0 and 1, making it suitable for binary classification problems.

**Usage:** Binary classification, logistic regression.

::: columns
::: {.column width="50%"}
$$
\sigma(x) = \frac{1}{1 + e^{-x}}
$$

```python
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-sigmoid >}}
:::
:::

## Hyperbolic Tangent (Tanh) {#sec-tanh}

**Strengths:** Similar to sigmoid, but maps to (-1, 1), which can be beneficial for some models.

**Usage:** Similar to sigmoid, but with a larger output range.

::: columns
::: {.column width="50%"}
$$
\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
$$

```python
def tanh(x):
    return np.tanh(x)
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-tanh >}}
:::
:::

## Rectified Linear Unit (ReLU)

**Strengths:** Computationally efficient, non-saturating, and easy to compute.

**Usage:** Default activation function in many deep learning frameworks, suitable for most neural networks.

::: columns
::: {.column width="50%"}
$$
\text{ReLU}(x) = \max(0, x)
$$

```python
def relu(x):
    return np.maximum(0, x)
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-relu >}}
:::
:::

## Leaky ReLU

**Strengths:** Similar to ReLU, but allows a small fraction of the input to pass through, helping with dying neurons.

**Usage:** Alternative to ReLU, especially when dealing with dying neurons.

::: columns
::: {.column width="50%"}
$$
\text{Leaky ReLU}(x) =
\begin{cases}
x & \text{if } x > 0 \\
\alpha x & \text{if } x \leq 0
\end{cases}
$$

```python
def leaky_relu(x, alpha=0.01):
    # where alpha is a small constant (e.g., 0.01)
    return np.where(x > 0, x, x * alpha)
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-leaky_relu >}}
:::
:::

## Swish

**Formula:** $f(x) = x \cdot g(x)$, where $g(x)$ is a learned function (e.g., sigmoid or ReLU).

**Strengths:** Self-gated, adaptive, and non-saturating.

**Usage:** Can be used in place of ReLU or other activations, but may not always outperform them.

::: columns
::: {.column width="50%"}
$$
\text{Swish}(x) = x \cdot \sigma(x)
$$

```python
def swish(x):
    return x * sigmoid(x)
```

See also: [sigmoid](#sec-sigmoid)
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-swish >}}
:::
:::

## Mish

**Strengths:** Non-saturating, smooth, and computationally efficient.

**Weaknesses:** Not as well-studied as ReLU or other activations.

**Usage:** Alternative to ReLU, especially in computer vision tasks.

::: columns
::: {.column width="50%"}
$$
\text{Mish}(x) = x \cdot \tanh(\text{Softplus}(x))
$$

```python
def mish(x):
    return x * np.tanh(softplus(x))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-mish >}}
:::
:::

See also: [softplus](#softplus), [tanh](#sec-tanh)

## Softmax

**Strengths:** Normalizes output to ensure probabilities sum to 1, making it suitable for multi-class classification.

**Usage:** Output layer activation for multi-class classification problems.

::: columns
::: {.column width="50%"}
$$
\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}}
$$

```python
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-softmax >}}
:::
:::

## Softsign

**Strengths:** Similar to sigmoid, but with a more gradual slope.

**Usage:** Alternative to sigmoid or tanh in certain situations.

::: columns
::: {.column width="50%"}
$$
\text{Softsign}(x) = \frac{x}{1 + |x|}
$$

```python
def softsign(x):
    return x / (1 + np.abs(x))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-softsign >}}
:::
:::

## SoftPlus {#softplus}

**Strengths:** Smooth, continuous, and non-saturating.

**Weaknesses:** Not commonly used, may not outperform other activations.

**Usage:** Experimental or niche applications.

::: columns
::: {.column width="50%"}
$$
\text{Softplus}(x) = \log(1 + e^x)
$$

```python
def softplus(x):
    return np.log1p(np.exp(x))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-softplus >}}
:::
:::

## ArcTan

**Strengths:** Non-saturating, smooth, and continuous.

**Weaknesses:** Not commonly used, may not outperform other activations.

**Usage:** Experimental or niche applications.

::: columns
::: {.column width="50%"}
$$
f(x) = \arctan(x)
$$

```python
def arctan(x):
    return np.arctan(x)
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-arctan >}}
:::
:::

## Gaussian Error Linear Unit (GELU)

**Strengths:** Non-saturating, smooth, and computationally efficient.

**Usage:** Alternative to ReLU, especially in Bayesian neural networks.

::: columns
::: {.column width="50%"}
$$
\text{GELU}(x) = x \cdot \Phi(x)
$$

```python
def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi)
                                  * (x + 0.044715 * np.power(x, 3))))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-gelu >}}
:::
:::

See also: [tanh](#sec-tanh)

## SiLU (Sigmoid Linear Unit)

**Strengths:** Non-saturating, smooth, and computationally efficient.

**Usage:** Alternative to ReLU, especially in computer vision tasks.

::: columns
::: {.column width="50%"}
$$
\text{SiLU}(x) = x \cdot \sigma(x)
$$

```python
def silu(x):
    return x / (1 + np.exp(-x))
```
:::

::: {.column width="50%"}
{{< embed ActivationFunctions.ipynb#fig-silu >}}
:::
:::

## GELU Approximation (GELU Approx.)

$$
f(x) \approx 0.5\,x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
$$

**Strengths:** Fast, non-saturating, and smooth.

**Usage:** Alternative to GELU, especially when computational efficiency is crucial.

## SELU (Scaled Exponential Linear Unit)

$$
f(x) = \lambda
\begin{cases}
x & x > 0 \\
\alpha e^x - \alpha & x \leq 0
\end{cases}
$$

**Strengths:** Self-normalizing, non-saturating, and computationally efficient.

**Usage:** Alternative to ReLU, especially in deep neural networks.
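The SELU entry gives the formula but, unlike the other activations, no snippet. A minimal NumPy sketch, assuming the commonly published constants λ ≈ 1.0507 and α ≈ 1.6733 (not stated in the original), could look like this:

```python
import numpy as np

def selu(x, lam=1.0507, alpha=1.6733):
    # lam scales the whole output; alpha sets the negative saturation value.
    # np.minimum keeps exp() from overflowing on large positive inputs.
    return lam * np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))
```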
src/theory/activations_slideshow.qmd (ADDED, +355 lines)

A new Reveal.js slideshow that mirrors the content of `activations.qmd`. Its front matter is:

---
title: "Activation functions in Neural Networks"
author: "Sébastien De Greef"
format:
  revealjs:
    theme: solarized
    navigation-mode: grid
    controls-layout: bottom-right
    controls-tutorial: true
---

The deck opens with the same activation-selection guidance and then repeats the per-function sections (Sigmoid through SELU) from `activations.qmd` verbatim, with the same formulas, NumPy snippets, and embedded figures. The only substantive addition is an explicit **Weaknesses** line on each slide:

- **Sigmoid:** Saturates (output values approach 0 or 1) for large inputs, leading to vanishing gradients during backpropagation.
- **Tanh:** Also saturates, leading to vanishing gradients.
- **ReLU:** Not differentiable at x = 0, which can cause issues during optimization.
- **Leaky ReLU:** Still non-differentiable at x = 0.
- **Swish:** Computationally expensive, requires additional learnable parameters.
- **Mish:** Not as well-studied as ReLU or other activations.
- **Softmax:** Only suitable for output layers with multiple classes.
- **Softsign:** Not commonly used, may not provide significant benefits over sigmoid or tanh.
- **SoftPlus:** Not commonly used, may not outperform other activations.
- **ArcTan:** Not commonly used, may not outperform other activations.
- **GELU:** Not as well-studied as ReLU or other activations.
- **SiLU:** Not as well-studied as ReLU or other activations.
- **GELU approximation:** An approximation, not exactly equal to GELU.
- **SELU:** Requires careful initialization and α tuning.

The deck closes with `\listoffigures`.
src/theory/architectures.qmd (CHANGED)

The previous version was a terse bulleted catalogue: one line of description plus strengths and caveats per architecture under numbered headings (e.g., "## **9. U-Net**"), with short entries that also covered graph neural networks, reinforcement-learning approaches, and evolutionary methods. The rewrite expands each architecture into prose with a Graphviz diagram and a walkthrough of its components. The updated file reads:

---
title: Network Architectures
---

Neural network architectures are foundational frameworks designed to tackle diverse problems in artificial intelligence and machine learning. Each architecture is structured to optimize learning and performance for specific types of data and tasks, ranging from simple classification problems to complex sequence generation challenges. This guide explores the various architectures employed in neural networks, providing insights into how they are constructed, their applications, and why certain architectures are preferred for particular tasks.

The architecture of a neural network dictates how information flows and is processed. It determines the arrangement and connectivity of layers, the type of data processing that occurs, and how input data is ultimately transformed into outputs. The choice of a suitable architecture is crucial because it impacts the efficiency, accuracy, and feasibility of training models on given datasets.

## Feedforward Neural Networks (FNNs)

A basic neural network architecture where data flows only in one direction, from input layer to output layer, without any feedback loops. Feedforward Neural Networks are the simplest type of neural network architecture where connections between the nodes do not form a cycle. This is ideal for problems where the output is directly mapped from the input.

* **Usage**: Image classification, regression, function approximation
* **Strengths**: Simple to implement, computationally efficient
* **Caveats**: Limited capacity to model complex relationships, prone to overfitting

```{dot}
digraph FNN {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_layer [label="Input Layer" shape=ellipse];
  hidden_layer1 [label="Hidden Layer 1" shape=box];
  hidden_layer2 [label="Hidden Layer 2" shape=box];
  output_layer [label="Output Layer" shape=ellipse];

  // Edges definitions
  input_layer -> hidden_layer1;
  hidden_layer1 -> hidden_layer2;
  hidden_layer2 -> output_layer;
}
```

**Input Layer**: This layer represents the initial data that is fed into the network. Each node in this layer typically corresponds to a feature in the input dataset.

**Hidden Layers**: These are intermediary layers between the input and output layers. Hidden layers allow the network to learn complex patterns in the data. They are called "hidden" because they are not directly exposed to the input or output.

**Output Layer**: The final layer that produces the network's predictions. The function of this layer can vary depending on the specific application, for example it might use a softmax activation function for classification tasks or a linear activation for regression tasks.

**Edges**: Represent the connections between neurons in consecutive layers. In feedforward networks, every neuron in one layer connects to every neuron in the next layer. These connections are weighted, and these weights are adjusted during training to minimize error.
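To make the layer arrangement above concrete, here is a minimal sketch of such a feedforward network. It uses PyTorch and arbitrary layer sizes, neither of which comes from the original page, so treat it as an illustration rather than the author's implementation:

```python
import torch.nn as nn

# A small feedforward network matching the diagram:
# input -> hidden 1 -> hidden 2 -> output, with no feedback loops.
model = nn.Sequential(
    nn.Linear(784, 128),  # input layer -> hidden layer 1
    nn.ReLU(),
    nn.Linear(128, 64),   # hidden layer 1 -> hidden layer 2
    nn.ReLU(),
    nn.Linear(64, 10),    # hidden layer 2 -> output layer (e.g., 10 classes)
)
```

Calling `model(x)` on a batch of flattened inputs runs exactly the one-directional pass shown in the diagram.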
## Convolutional Neural Networks (CNNs)

A neural network architecture that uses convolutional and pooling layers to extract features from images. CNNs are highly effective at processing data that has a grid-like topology, such as images, due to their ability to exploit spatial hierarchies and structures within the data.

* **Usage**: Image classification, object detection, image segmentation
* **Strengths**: Excellent performance on image-related tasks, robust to image transformations
* **Caveats**: Computationally expensive, require large datasets

```{dot}
digraph CNN {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_image [label="Input Image" shape=ellipse];
  conv1 [label="Convolution Layer 1\nReLU" shape=box];
  pool1 [label="Pooling Layer 1" shape=box];
  conv2 [label="Convolution Layer 2\nReLU" shape=box];
  pool2 [label="Pooling Layer 2" shape=box];
  fully_connected [label="Fully Connected Layer" shape=box];
  output [label="Output\n(Classification)" shape=ellipse];

  // Edges definitions
  input_image -> conv1;
  conv1 -> pool1;
  pool1 -> conv2;
  conv2 -> pool2;
  pool2 -> fully_connected;
  fully_connected -> output;
}
```

**Input Image**: The initial input where images are fed into the network.

**Convolution Layer 1 and 2**: These layers apply a set of filters to the input image to create feature maps. These filters are designed to detect spatial hierarchies such as edges, colors, gradients, and more complex patterns as the network deepens. Each convolution layer is typically followed by a non-linear activation function like ReLU (Rectified Linear Unit).

**Pooling Layer 1 and 2**: These layers reduce the spatial size of the feature maps to decrease the amount of computation and weights in the network. Pooling (often max pooling) helps make the detection of features invariant to scale and orientation changes.

**Fully Connected Layer**: This layer takes the flattened output of the last pooling layer and performs classification based on the features extracted by the convolutional and pooling layers.

**Output**: The final output layer, which classifies the input image into categories based on the training dataset.
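As an illustration of the convolution, pooling, and fully connected pipeline just described, a small PyTorch sketch (library choice, channel counts, and the 32x32 input size are assumptions) might look like:

```python
import torch.nn as nn

# Two conv+ReLU+pool stages followed by a classifier, mirroring the diagram.
cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),   # Convolution Layer 1
    nn.ReLU(),
    nn.MaxPool2d(2),                              # Pooling Layer 1
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # Convolution Layer 2
    nn.ReLU(),
    nn.MaxPool2d(2),                              # Pooling Layer 2
    nn.Flatten(),
    nn.Linear(32 * 8 * 8, 10),  # fully connected layer for 32x32 inputs, 10 classes
)
```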
## Recurrent Neural Networks (RNNs)

A neural network architecture that uses feedback connections to model sequential data. RNNs are capable of processing sequences of data by maintaining a state that acts as a memory. They are particularly useful for applications where the context or sequence of data points is important.

* **Usage**: Natural Language Processing (NLP), sequence prediction, time series forecasting
* **Strengths**: Excellent performance on sequential data, can model long-term dependencies
* **Caveats**: Suffer from vanishing gradients, difficult to train

```{dot}
digraph RNN {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_seq [label="Input Sequence" shape=ellipse];
  rnn_cell [label="RNN Cell" shape=box];
  hidden_state [label="Hidden State" shape=box];
  output_seq [label="Output Sequence" shape=ellipse];

  // Edges definitions
  input_seq -> rnn_cell;
  rnn_cell -> hidden_state [label="Next"];
  hidden_state -> rnn_cell [label="Recurrence", dir=back];
  rnn_cell -> output_seq;

  // Additional details for clarity
  edge [style=dashed];
  rnn_cell -> output_seq [label="Each timestep", style=dashed];
}
```

**Input Sequence**: Represents the sequence of data being fed into the RNN, such as a sentence or time-series data.

**RNN Cell**: This is the core of an RNN, where the computation happens. It takes input from the current element of the sequence and combines it with the hidden state from the previous element of the sequence.

**Hidden State**: This node represents the memory of the network, carrying information from one element of the sequence to the next. The hidden state is updated continuously as the sequence is processed.

**Output Sequence**: The RNN can produce an output at each timestep, depending on the task. For example, in sequence labeling, there might be an output corresponding to each input timestep.
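The recurrence between the cell and its hidden state can also be written out directly. The following NumPy sketch of a plain RNN loop (the tanh nonlinearity and weight shapes are conventional assumptions, not taken from the page) shows how each input element updates the hidden state and yields a per-timestep output:

```python
import numpy as np

def rnn_forward(x_seq, W_xh, W_hh, b_h):
    """Run a plain RNN over a sequence; x_seq has shape (timesteps, input_dim)."""
    h = np.zeros(W_hh.shape[0])  # initial hidden state (the "memory")
    outputs = []
    for x_t in x_seq:
        # combine the current input with the previous hidden state (the recurrence)
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)
        outputs.append(h)        # one output per timestep
    return np.stack(outputs), h
```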
## Long Short-Term Memory (LSTM) Networks

A type of RNN that uses memory cells to learn long-term dependencies. LSTM networks are designed to avoid the long-term dependency problem, making them effective at tasks where the context can extend over longer sequences.

* **Usage**: NLP, sequence prediction, time series forecasting
* **Strengths**: Excellent performance on sequential data, can model long-term dependencies
* **Caveats**: Computationally expensive, require large datasets

```{dot}
digraph LSTM {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_seq [label="Input Sequence" shape=ellipse];
  lstm_cell [label="LSTM Cell" shape=box];
  cell_state [label="Cell State" shape=box];
  hidden_state [label="Hidden State" shape=box];
  output_seq [label="Output Sequence" shape=ellipse];

  // Edges definitions
  input_seq -> lstm_cell;
  cell_state -> lstm_cell [label="Recurrence", dir=back];
  hidden_state -> lstm_cell [label="Recurrence", dir=back];
  lstm_cell -> cell_state [label="Update"];
  lstm_cell -> hidden_state [label="Update"];
  lstm_cell -> output_seq;

  // Additional explanations
  edge [style=dashed];
  lstm_cell -> output_seq [label="Each timestep", style=dashed];
}
```

**Input Sequence**: Represents the sequential data input, such as a series of words or time-series data points.

**LSTM Cell**: The core unit in an LSTM network that processes input data one element at a time. It interacts intricately with both the cell state and the hidden state to manage and preserve information over long periods.

**Cell State**: A "long-term" memory component of the LSTM cell. It carries relevant information throughout the processing of the sequence, with the ability to add or remove information via gates (not explicitly shown here).

**Hidden State**: A "short-term" memory component that also transfers information to the next time step but is more sensitive and responsive to recent inputs than the cell state.

**Output Sequence**: Depending on the task, LSTMs can output at each timestep (for tasks like sequence labeling) or after processing the entire sequence (like sentiment analysis).

## Transformers

A neural network architecture that uses self-attention mechanisms to model relationships between input sequences. Transformers are particularly effective in NLP tasks due to their ability to handle sequences in parallel and consider all parts of the input at once.

* **Usage**: NLP, machine translation, language modeling
* **Strengths**: Excellent performance on sequential data, parallelizable, can handle long-range dependencies
* **Caveats**: Computationally expensive, require large datasets

```{dot}
digraph Transformer {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_tokens [label="Input Tokens" shape=ellipse];
  embedding_layer [label="Embedding Layer" shape=box];
  positional_encoding [label="Add Positional Encoding" shape=box];
  encoder [label="Encoder Stack" shape=box];
  decoder [label="Decoder Stack" shape=box];
  output_tokens [label="Output Tokens" shape=ellipse];

  // Edges definitions
  input_tokens -> embedding_layer;
  embedding_layer -> positional_encoding;
  positional_encoding -> encoder;
  encoder -> decoder;
  decoder -> output_tokens;

  // Additional components for clarity (not actual flow)
  encoder_output [label="Encoder Output" shape=note];
  decoder_input [label="Decoder Input" shape=note];
  encoder -> encoder_output [style=dashed];
  decoder_input -> decoder [style=dashed];

  // Descriptions for self-attention and cross-attention mechanisms
  self_attention [label="Self-Attention" shape=plaintext];
  cross_attention [label="Cross-Attention" shape=plaintext];
  encoder -> self_attention [style=dotted];
  decoder -> self_attention [style=dotted];
  decoder -> cross_attention [style=dotted];
  cross_attention -> encoder_output [style=dotted, dir=none];
}
```

**Input Tokens**: Represents the initial sequence of tokens (e.g., words in a sentence) that are fed into the Transformer.

**Embedding Layer**: Converts tokens into vectors that the model can process. Each token is mapped to a unique vector.

**Positional Encoding**: Adds information about the position of each token in the sequence to the embeddings, which is crucial as Transformers do not inherently process sequential data.

**Encoder Stack**: A series of encoder layers that process the input. Each layer uses self-attention mechanisms to consider all parts of the input simultaneously.

**Decoder Stack**: A series of decoder layers that generate the output sequence step by step. Each layer uses both self-attention mechanisms to attend to its own output so far, and cross-attention mechanisms to focus on the output from the encoder.

**Output Tokens**: The final output sequence generated by the Transformer, such as a translated sentence or the continuation of an input text.

**Encoder Output and Decoder Input**: Not actual data flow, but illustrate how information is transferred from the encoder to the decoder.

**Self-Attention and Cross-Attention**: These mechanisms are core features of Transformer models. Self-attention allows layers to consider other parts of the input or output at each step, while cross-attention allows the decoder to focus on relevant parts of the input sequence.
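Since self-attention is the core operation inside both the encoder and decoder stacks, a minimal NumPy sketch of scaled dot-product attention may help; it follows the standard formulation, which the page describes only in prose:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q, K, V: arrays of shape (seq_len, d_k); returns the attended values."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # similarity of every query to every key
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ V                               # weighted sum of the values
```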
## Autoencoders

A neural network architecture that learns to compress and reconstruct input data. Autoencoders are typically used for dimensionality reduction tasks, as they learn to encode the essential aspects of the data in a smaller representation.

* **Usage**: Dimensionality reduction, anomaly detection, generative modeling
* **Strengths**: Excellent performance on dimensionality reduction, can learn robust representations
* **Caveats**: May not perform well on complex data distributions

```{dot}
digraph Autoencoder {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input_data [label="Input Data" shape=ellipse];
  encoder [label="Encoder" shape=box];
  latent_space [label="Latent Space" shape=box];
  decoder [label="Decoder" shape=box];
  reconstructed_output [label="Reconstructed Output" shape=ellipse];

  // Edges definitions
  input_data -> encoder;
  encoder -> latent_space;
  latent_space -> decoder;
  decoder -> reconstructed_output;
}
```

**Input Data**: Represents the data that is fed into the Autoencoder. This could be any kind of data, such as images, text, or sound.

**Encoder**: The first part of the Autoencoder that processes the input data and compresses it into a smaller, dense representation. This part typically consists of several layers that gradually reduce the dimensionality of the input.

**Latent Space**: Also known as the "encoded" state or "bottleneck". This is a lower-dimensional representation of the input data and serves as the compressed "code" that the decoder will use to reconstruct the input.

**Decoder**: Mirrors the structure of the encoder but in reverse. It takes the encoded data from the latent space and reconstructs the original data as closely as possible. This part typically consists of layers that gradually increase in dimensionality to match the original input size.

**Reconstructed Output**: The final output of the Autoencoder. This is the reconstruction of the original input data based on the compressed code stored in the latent space. The quality of this reconstruction is often a measure of the Autoencoder's performance.
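A compact encoder, bottleneck, and decoder can be expressed in a few lines. This PyTorch sketch (the dimensions and the suggested MSE reconstruction loss are illustrative assumptions) mirrors the diagram above:

```python
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        # Encoder: gradually reduce dimensionality down to the latent space
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        # Decoder: mirror the encoder to reconstruct the original input size
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# Reconstruction quality is typically measured with a loss such as nn.MSELoss().
```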
## Generative Adversarial Networks (GANs)

A neural network architecture that consists of a generator and discriminator, which compete to generate realistic data. GANs are highly effective at generating new data that mimics the input data, often used in image generation and editing.

* **Usage**: Generative modeling, data augmentation, style transfer
* **Strengths**: Excellent performance on generative tasks, can generate realistic data
* **Caveats**: Training can be unstable, require careful tuning of hyperparameters

```{dot}
digraph GAN {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  noise [label="Noise vector (z)" shape=ellipse];
  generator [label="Generator (G)" shape=box];
  generated_image [label="Generated image (G(z))" shape=cds];
  real_image [label="Real image (x)" shape=cds];
  discriminator [label="Discriminator (D)" shape=box];
  D_output_fake [label="D(G(z))" shape=ellipse];
  D_output_real [label="D(x)" shape=ellipse];

  // Edges definitions
  noise -> generator;
  generator -> generated_image;
  generated_image -> discriminator [label="Fake"];
  real_image -> discriminator [label="Real"];
  discriminator -> D_output_fake [label="Output for fake"];
  discriminator -> D_output_real [label="Output for real"];
}
```

**Noise vector (z)**: Represents the random noise input to the generator.

**Generator (G)**: The model that learns to generate new data with the same statistics as the training set from the noise vector.

**Generated image (G(z))**: The fake data produced by the generator.

**Real image (x)**: Actual data samples from the training dataset.

**Discriminator (D)**: The model that learns to distinguish between real data and synthetic data generated by the Generator.

**D(G(z)) and D(x)**: Outputs of the Discriminator when evaluating fake data and real data, respectively.

The Noise vector feeds into the Generator. The Generator outputs a Generated image, which is input to the Discriminator labeled as "Fake". The Real image also feeds into the Discriminator but is labeled as "Real". The Discriminator outputs evaluations for both fake and real inputs.
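To connect the diagram to an actual optimization loop, here is a heavily simplified single training step in PyTorch. The generator, discriminator, and optimizers are assumed to be defined elsewhere, and the whole snippet is an illustrative sketch rather than code from this repository:

```python
import torch
import torch.nn.functional as F

def gan_step(G, D, opt_G, opt_D, real, latent_dim=100):
    """One adversarial update: D learns to separate real from fake, then G learns to fool D."""
    z = torch.randn(real.size(0), latent_dim)
    fake = G(z)
    ones = torch.ones(real.size(0), 1)
    zeros = torch.zeros(real.size(0), 1)

    # Discriminator step: push D(x) toward 1 and D(G(z)) toward 0
    d_loss = (F.binary_cross_entropy(D(real), ones)
              + F.binary_cross_entropy(D(fake.detach()), zeros))
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator step: push D(G(z)) toward 1 so the fakes pass as real
    g_loss = F.binary_cross_entropy(D(fake), ones)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```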
## Residual Networks (ResNets)

A neural network architecture that uses residual connections to ease training. ResNets are particularly effective for very deep networks, as they allow for training deeper networks by providing pathways for gradients to flow through.

* **Usage**: Image classification, object detection
* **Strengths**: Excellent performance on image-related tasks, ease of training
* **Caveats**: May not perform well on sequential data

```{dot}
digraph ResNet {
  // Graph properties
  node [shape=record];

  // Nodes definitions
  input [label="Input Image" shape=ellipse];
  conv1 [label="Initial Conv + BN + ReLU" shape=box];
  resblock1 [label="<f0> ResBlock | <f1> + | <f2> ReLU" shape=Mrecord];
  resblock2 [label="<f0> ResBlock | <f1> + | <f2> ReLU" shape=Mrecord];
  resblock3 [label="<f0> ResBlock | <f1> + | <f2> ReLU" shape=Mrecord];
  avgpool [label="Average Pooling" shape=box];
  fc [label="Fully Connected Layer" shape=box];
  output [label="Output" shape=ellipse];

  // Edges definitions
  input -> conv1;
  conv1 -> resblock1:f0;
  resblock1:f2 -> resblock2:f0;
  resblock2:f2 -> resblock3:f0;
  resblock3:f2 -> avgpool;
  avgpool -> fc;
  fc -> output;

  // Adding skip connections
  edge [style=dashed];
  conv1 -> resblock1:f1;
  resblock1:f1 -> resblock2:f1;
  resblock2:f1 -> resblock3:f1;
}
```

**Input Image**: The initial input layer where images are fed into the network.

**Initial Conv + BN + ReLU**: Represents an initial convolutional layer followed by batch normalization and a ReLU activation function to prepare the data for residual blocks.

**ResBlock**: These are the residual blocks that define the ResNet architecture. Each block contains two parts: a sequence of convolutional layers and a skip connection that adds the input of the block to its output.

**Average Pooling**: This layer averages the feature maps spatially to reduce their dimensions before passing to a fully connected layer.

**Fully Connected Layer**: This layer maps the feature representations to the final output classes.

**Output**: The final prediction of the network.
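The skip connection inside a residual block is easiest to see in code. This PyTorch sketch of a basic block (the two-convolution layout and channel counts are assumptions in line with the description above) adds the block's input back to its output before the final ReLU:

```python
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)  # skip connection: add the block's input back in
```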
|
367 |
+
|
368 |
+
|
369 |
+
## U-Net
|
370 |
+
|
371 |
+
A neural network architecture that uses an encoder-decoder structure with skip connections. U-Net is designed primarily for biomedical image segmentation, where it is crucial to localize objects precisely within an image.
|
372 |
+
|
373 |
+
* **Usage**: Image segmentation, object detection
|
374 |
+
* **Strengths**: Excellent performance on image segmentation tasks, fast training
|
375 |
+
* **Caveats**: May not perform well on sequential data
|
376 |
+
```{dot}
|
377 |
+
digraph UNet {
|
378 |
+
// Graph properties
|
379 |
+
node [shape=record];
|
380 |
+
|
381 |
+
// Nodes definitions
|
382 |
+
input [label="Input Image" shape=ellipse];
|
383 |
+
conv1 [label="Conv + ReLU\nDownsampling" shape=box];
|
384 |
+
conv2 [label="Conv + ReLU\nDownsampling" shape=box];
|
385 |
+
bottom [label="Conv + ReLU" shape=box];
|
386 |
+
upconv1 [label="UpConv + ReLU\nUpsampling" shape=box];
|
387 |
+
concat1 [label="Concatenate" shape=circle];
|
388 |
+
upconv2 [label="UpConv + ReLU\nUpsampling" shape=box];
|
389 |
+
concat2 [label="Concatenate" shape=circle];
|
390 |
+
finalconv [label="Conv + ReLU\n1x1 Conv" shape=box];
|
391 |
+
output [label="Output\nSegmentation Map" shape=ellipse];
|
392 |
+
|
393 |
+
// Edges definitions
|
394 |
+
input -> conv1;
|
395 |
+
conv1 -> conv2;
|
396 |
+
conv2 -> bottom;
|
397 |
+
bottom -> upconv1;
|
398 |
+
upconv1 -> concat1;
|
399 |
+
concat1 -> upconv2;
|
400 |
+
upconv2 -> concat2;
|
401 |
+
concat2 -> finalconv;
|
402 |
+
finalconv -> output;
|
403 |
+
|
404 |
+
// Skip connections
|
405 |
+
edge [style=dashed];
|
406 |
+
conv1 -> concat1 [label="Copy\ncrop"];
|
407 |
+
conv2 -> concat2 [label="Copy\ncrop"];
|
408 |
+
}
|
409 |
+
```
|
410 |
+
|
411 |
+
**Input Image**: The initial input layer where images are fed into the network.
|
412 |
+
|
413 |
+
**Conv + ReLU / Downsampling**: These blocks represent convolutional operations followed by a ReLU activation function. The "Downsampling" indicates that each block reduces the spatial dimensions of the input.
|
414 |
+
|
415 |
+
**Bottom**: This is the lowest part of the U, consisting of convolutional layers without downsampling, positioned before the upsampling starts.
|
416 |
+
|
417 |
+
**UpConv + ReLU / Upsampling**: These blocks perform transposed convolutions (or up-convolutions) that increase the resolution of the feature maps.
|
418 |
+
|
419 |
+
**Concatenate**: These layers concatenate feature maps from the downsampling pathway with the upsampled feature maps to preserve high-resolution features for precise localization.
|
420 |
+
|
421 |
+
**Final Conv**: This typically includes a 1x1 convolution to map the deep feature representations to the desired number of classes for segmentation.
|
422 |
+
|
423 |
+
**Output / Segmentation Map**: The final output layer which produces the segmented image.
|
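To make the encoder-decoder wiring explicit, here is a heavily simplified, illustrative PyTorch sketch with a single downsampling and upsampling step (channel counts are arbitrary and chosen for clarity): encoder features are concatenated with the upsampled decoder features before the final 1x1 convolution, exactly as the dashed "copy and crop" edges indicate.

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """One-level U-Net sketch: downsample, bottom, upsample, concatenate, 1x1 conv."""
    def __init__(self, in_ch=1, num_classes=2):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(in_ch, 16, 3, padding=1), nn.ReLU())
        self.down = nn.MaxPool2d(2)
        self.bottom = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.final = nn.Conv2d(16 + 16, num_classes, kernel_size=1)  # after concatenation

    def forward(self, x):
        e = self.enc(x)                       # encoder features (skip connection source)
        b = self.bottom(self.down(e))         # bottom of the "U"
        u = self.up(b)                        # upsample back to the encoder resolution
        return self.final(torch.cat([u, e], dim=1))  # concatenate skip + upsampled

net = TinyUNet()
seg = net(torch.randn(1, 1, 64, 64))  # -> (1, 2, 64, 64) segmentation logits
```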
424 |
+
|
425 |
+
## Attention-based Models
|
426 |
+
|
427 |
+
A neural network architecture that uses attention mechanisms to focus on relevant input regions. Attention-based models are particularly effective for tasks that require understanding of complex relationships within the data, such as interpreting a document or translating a sentence.
|
428 |
+
|
429 |
+
* **Usage**: NLP, machine translation, question answering
|
430 |
+
* **Strengths**: Excellent performance on sequential data, can model long-range dependencies
|
431 |
+
* **Caveats**: Require careful tuning of hyperparameters
|
432 |
+
```{dot}
|
433 |
+
digraph AttentionBasedModels {
|
434 |
+
// Graph properties
|
435 |
+
node [shape=record];
|
436 |
+
|
437 |
+
// Nodes definitions
|
438 |
+
input [label="Input Sequence" shape=ellipse];
|
439 |
+
embedding [label="Embedding Layer" shape=box];
|
440 |
+
positional [label="Add Positional Encoding" shape=box];
|
441 |
+
multihead [label="Multi-Head Attention" shape=box];
|
442 |
+
addnorm1 [label="Add & Norm" shape=box];
|
443 |
+
feedforward [label="Feedforward Network" shape=box];
|
444 |
+
addnorm2 [label="Add & Norm" shape=box];
|
445 |
+
output [label="Output Sequence" shape=ellipse];
|
446 |
+
|
447 |
+
// Edges definitions
|
448 |
+
input -> embedding;
|
449 |
+
embedding -> positional;
|
450 |
+
positional -> multihead;
|
451 |
+
multihead -> addnorm1;
|
452 |
+
addnorm1 -> feedforward;
|
453 |
+
feedforward -> addnorm2;
|
454 |
+
addnorm2 -> output;
|
455 |
+
|
456 |
+
// Skip connections
|
457 |
+
edge [style=dashed];
|
458 |
+
positional -> addnorm1 [label="Skip Connection"];
|
459 |
+
addnorm1 -> addnorm2 [label="Skip Connection"];
|
460 |
+
}
|
461 |
+
|
462 |
+
```
|
463 |
+
**Input Sequence**: Initial data input, typically a sequence of tokens.
|
464 |
+
|
465 |
+
**Embedding Layer**: Converts tokens into vectors that the model can process.
|
466 |
+
|
467 |
+
**Add Positional Encoding**: Incorporates information about the position of tokens in the sequence into their embeddings, which is crucial since attention mechanisms do not inherently process sequential data.
|
468 |
+
|
469 |
+
**Multi-Head Attention**: Allows the model to focus on different parts of the sequence for different representations, facilitating better understanding and processing of the input.
|
470 |
+
|
471 |
+
**Add & Norm**: A layer that combines residuals (from skip connections) with the output of the attention or feedforward layers, followed by layer normalization.
|
472 |
+
|
473 |
+
**Feedforward Network**: A dense neural network that processes the sequence after attention has been applied.
|
474 |
+
|
475 |
+
**Output Sequence**: The final processed sequence output by the model, often used for tasks like translation, text generation, or classification.
|
476 |
+
|
477 |
+
**Skip Connections**: Dashed lines represent skip connections that help to alleviate the vanishing gradient problem by allowing gradients to flow through the network directly. They also make it easy for a sub-layer to learn an identity mapping, so information is not lost as it passes through the layers.
|
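The core of the multi-head attention block is scaled dot-product attention. Below is a minimal, illustrative sketch in plain PyTorch (single head, no masking, no learned projections) just to show how queries, keys, and values interact; real transformer blocks add linear projections per head, masking, and the Add & Norm steps described above.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V -- single head, no mask."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # (batch, seq, seq)
    weights = torch.softmax(scores, dim=-1)              # attention weights
    return weights @ v                                    # weighted sum of values

batch, seq_len, d_model = 2, 10, 64
x = torch.randn(batch, seq_len, d_model)
# In a real block, q, k, v come from learned linear projections of x.
out = scaled_dot_product_attention(x, x, x)   # self-attention
print(out.shape)  # torch.Size([2, 10, 64])
```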
478 |
+
|
479 |
+
## Graph Neural Networks (GNNs)
|
480 |
+
|
481 |
+
A neural network architecture that uses graph structures to model relationships between nodes. GNNs are effective for data that can be represented as graphs, such as social networks or molecules, as they capture the relationships between entities.
|
482 |
+
|
483 |
+
* **Usage**: Graph-based data, social network analysis, recommendation systems
|
484 |
+
* **Strengths**: Excellent performance on graph-based data, can model complex relationships
|
485 |
+
* **Caveats**: Computationally expensive, require large datasets
|
486 |
+
```{dot}
|
487 |
+
digraph GNN {
|
488 |
+
// Graph properties
|
489 |
+
node [shape=record];
|
490 |
+
|
491 |
+
// Nodes definitions
|
492 |
+
input_graph [label="Input Graph" shape=ellipse];
|
493 |
+
node_features [label="Node Features" shape=box];
|
494 |
+
edge_features [label="Edge Features" shape=box];
|
495 |
+
gnn_layers [label="GNN Layers" shape=box];
|
496 |
+
aggregate [label="Aggregate Messages" shape=box];
|
497 |
+
update [label="Update States" shape=box];
|
498 |
+
readout [label="Graph-level Readout" shape=box];
|
499 |
+
output [label="Output" shape=ellipse];
|
500 |
+
|
501 |
+
// Edges definitions
|
502 |
+
input_graph -> node_features;
|
503 |
+
input_graph -> edge_features;
|
504 |
+
node_features -> gnn_layers;
|
505 |
+
edge_features -> gnn_layers;
|
506 |
+
gnn_layers -> aggregate;
|
507 |
+
aggregate -> update;
|
508 |
+
update -> readout;
|
509 |
+
readout -> output;
|
510 |
+
}
|
511 |
+
```
|
512 |
+
|
513 |
+
**Input Graph**: The initial graph input containing nodes and edges.
|
514 |
+
|
515 |
+
**Node Features**: Processes the features associated with each node. These can include node labels, attributes, or other data.
|
516 |
+
|
517 |
+
**Edge Features**: Processes features associated with edges in the graph, which might include types of relationships, weights, or other characteristics.
|
518 |
+
|
519 |
+
**GNN Layers**: A series of graph neural network layers that apply convolution-like operations over the graph. These layers can involve message passing between nodes, where a node's new state is determined based on its neighbors.
|
520 |
+
|
521 |
+
**Aggregate Messages**: Combines the information (messages) received from neighboring nodes into a single unified message. Aggregation functions can include sums, averages, or max operations.
|
522 |
+
|
523 |
+
**Update States**: Updates the states of the nodes based on aggregated messages, typically using some form of neural network or transformation.
|
524 |
+
|
525 |
+
**Graph-level Readout**: Aggregates node states into a graph-level representation, which can be used for tasks that require a holistic view of the graph (e.g., determining the properties of a molecule).
|
526 |
+
|
527 |
+
**Output**: The final output, which can vary depending on the specific application (node classification, link prediction, graph classification, etc.).
|
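The aggregate-and-update step can be illustrated with a single round of message passing over an adjacency matrix. This is a minimal, framework-free NumPy sketch (mean aggregation followed by a shared linear update); real GNN layers add learned per-layer weights, edge features, and multiple rounds of message passing.

```python
import numpy as np

# Toy graph: 4 nodes, undirected edges (0-1, 1-2, 2-3).
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
features = np.random.randn(4, 8)    # one 8-dim feature vector per node
weight = np.random.randn(8, 8)      # shared update weight (illustrative)

# Aggregate: mean of neighbor features for each node.
degree = adj.sum(axis=1, keepdims=True)
messages = adj @ features / np.maximum(degree, 1)

# Update: combine each node's own state with its aggregated messages.
updated = np.tanh((features + messages) @ weight)

# Graph-level readout: e.g. mean over all node states.
graph_embedding = updated.mean(axis=0)
print(graph_embedding.shape)  # (8,)
```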
528 |
+
|
529 |
+
## Reinforcement Learning (RL) Architectures
|
530 |
+
|
531 |
+
A neural network architecture that uses reinforcement learning to learn from interactions with an environment. RL architectures are highly effective for sequential decision-making tasks, such as playing games or navigating environments.
|
532 |
+
|
533 |
+
* **Usage**: Game playing, robotics, autonomous systems
|
534 |
+
* **Strengths**: Excellent performance on sequential decision-making tasks, can learn complex policies
|
535 |
+
* **Caveats**: Require extensive interaction with the environment (sample-inefficient), can be slow to train
|
536 |
+
```{dot}
|
537 |
+
digraph RL {
|
538 |
+
// Graph properties
|
539 |
+
node [shape=record];
|
540 |
+
|
541 |
+
// Nodes definitions
|
542 |
+
environment [label="Environment" shape=ellipse];
|
543 |
+
state [label="State" shape=ellipse];
|
544 |
+
agent [label="Agent" shape=box];
|
545 |
+
action [label="Action" shape=ellipse];
|
546 |
+
reward [label="Reward" shape=ellipse];
|
547 |
+
updated_state [label="Updated State" shape=ellipse];
|
548 |
+
|
549 |
+
// Edges definitions
|
550 |
+
environment -> state;
|
551 |
+
state -> agent;
|
552 |
+
agent -> action;
|
553 |
+
action -> environment;
|
554 |
+
environment -> reward;
|
555 |
+
reward -> agent;
|
556 |
+
environment -> updated_state [label="Feedback Loop"];
|
557 |
+
updated_state -> state [label="New State"];
|
558 |
+
}
|
559 |
+
```
|
560 |
+
|
561 |
+
**Environment**: This is where the agent operates. It defines the dynamics of the system including how the states transition and how rewards are assigned for actions.
|
562 |
+
|
563 |
+
**State**: Represents the current situation or condition in which the agent finds itself. It is the information that the environment provides to the agent, which then bases its decisions on this data.
|
564 |
+
|
565 |
+
**Agent**: This is the decision-maker. It uses a strategy, which may involve a neural network or another function approximator, to decide what actions to take based on the state it perceives.
|
566 |
+
|
567 |
+
**Action**: The decision taken by the agent, which will affect the environment.
|
568 |
+
|
569 |
+
**Reward**: After taking an action, the agent receives a reward (or penalty) from the environment. This reward is an indication of how good the action was in terms of achieving the goal.
|
570 |
+
|
571 |
+
**Updated State**: After an action is taken, the environment transitions to a new state. This new state and the reward feedback are then used by the agent to learn and refine its strategy.
|
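This state-action-reward cycle maps directly onto the standard agent-environment interaction loop. The sketch below uses the Gymnasium API purely for illustration (assuming the `gymnasium` package is installed); the `CartPole-v1` environment and the random policy are stand-ins, and a real agent would choose actions from a learned policy and update it from the rewards it receives.

```python
import gymnasium as gym  # assumption: gymnasium is installed

env = gym.make("CartPole-v1")
state, info = env.reset(seed=0)

total_reward = 0.0
done = False
while not done:
    action = env.action_space.sample()           # stand-in for a learned policy
    state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward                        # feedback used to improve the policy
    done = terminated or truncated                # episode ends, then a new state begins

print("episode return:", total_reward)
env.close()
```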
572 |
+
|
573 |
+
## Evolutionary Neural Networks (ENNs)
|
574 |
+
|
575 |
+
A neural network architecture that uses evolutionary principles to evolve neural networks. Evolutionary Neural Networks are particularly effective for optimization problems, where they can evolve solutions over generations.
|
576 |
+
|
577 |
+
* **Usage**: Neuroevolution, optimization problems
|
578 |
+
* **Strengths**: Excellent performance on optimization problems, can learn complex policies
|
579 |
+
* **Caveats**: Computationally expensive, require many fitness evaluations per generation
|
580 |
+
```{dot}
|
581 |
+
digraph ENN {
|
582 |
+
// Graph properties
|
583 |
+
node [shape=record];
|
584 |
+
|
585 |
+
// Nodes definitions
|
586 |
+
population [label="Initial Population\n(Neural Networks)" shape=ellipse];
|
587 |
+
selection [label="Selection" shape=box];
|
588 |
+
crossover [label="Crossover" shape=box];
|
589 |
+
mutation [label="Mutation" shape=box];
|
590 |
+
fitness [label="Fitness Evaluation" shape=box];
|
591 |
+
new_population [label="New Generation" shape=ellipse];
|
592 |
+
best_network [label="Best Performing Network" shape=ellipse, fillcolor=lightblue];
|
593 |
+
|
594 |
+
// Edges definitions
|
595 |
+
population -> selection;
|
596 |
+
selection -> crossover;
|
597 |
+
crossover -> mutation;
|
598 |
+
mutation -> fitness;
|
599 |
+
fitness -> new_population;
|
600 |
+
new_population -> selection [label="Next Generation"];
|
601 |
+
fitness -> best_network [label="If Optimal", style=dashed];
|
602 |
+
|
603 |
+
// Additional explanatory nodes
|
604 |
+
edge [style=dashed];
|
605 |
+
best_network -> new_population [label="Update Population", style=dotted];
|
606 |
+
}
|
607 |
+
```
|
608 |
+
|
609 |
+
**Initial Population**: This represents the initial set of neural networks. These networks might differ in architecture, weights, or hyperparameters.
|
610 |
+
|
611 |
+
**Selection**: Part of the evolutionary process where individual networks are selected based on their performance, often using a fitness function.
|
612 |
+
|
613 |
+
**Crossover**: A genetic operation that combines features from two or more parent neural networks to create offspring, analogous to recombination in biological reproduction.
|
614 |
+
|
615 |
+
**Mutation**: Introduces random variations to the offspring, potentially leading to new neural network configurations. This step enhances diversity within the population.
|
616 |
+
|
617 |
+
**Fitness Evaluation**: Each network in the population is evaluated based on how well it performs the given task. The fitness often determines which networks survive and reproduce.
|
618 |
+
|
619 |
+
**New Generation**: After selection, crossover, mutation, and evaluation, a new generation of neural networks is formed. This generation forms the new population for further evolution.
|
620 |
+
|
621 |
+
**Best Performing Network**: Out of all generations, the network that performs best on the task.
|
622 |
+
|
623 |
+
**Feedback Loops**:
|
624 |
+
|
625 |
+
- **Next Generation**: The cycle from selection to fitness evaluation and then back to selection with the new generation is a loop that continues until a satisfactory solution (network) is found.
|
626 |
+
|
627 |
+
- **If Optimal**: If during any fitness evaluation a network meets the predefined optimality criteria, it may be selected as the final model.
|
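The select-crossover-mutate-evaluate cycle can be sketched without any framework. The toy NumPy example below evolves flat weight vectors (stand-ins for network parameters) against a hypothetical fitness function; it is illustrative only and glosses over encoding, speciation, and the other machinery of real neuroevolution systems.

```python
import numpy as np

rng = np.random.default_rng(0)
POP, DIM, GENERATIONS = 20, 10, 50

def fitness(weights):
    # Hypothetical fitness: higher is better (negative squared distance to a target).
    return -np.sum((weights - 0.5) ** 2)

population = rng.normal(size=(POP, DIM))          # initial population of "networks"

for _ in range(GENERATIONS):
    scores = np.array([fitness(w) for w in population])
    parents = population[np.argsort(scores)[-POP // 2:]]   # selection: keep the best half

    children = []
    for _ in range(POP - len(parents)):
        a, b = parents[rng.integers(len(parents), size=2)]
        mask = rng.random(DIM) < 0.5
        child = np.where(mask, a, b)              # crossover: mix genes of two parents
        child += rng.normal(scale=0.1, size=DIM)  # mutation: small random perturbation
        children.append(child)

    population = np.vstack([parents, children])   # new generation

best = population[np.argmax([fitness(w) for w in population])]
print("best fitness:", fitness(best))
```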
628 |
+
|
629 |
+
## Spiking Neural Networks (SNNs)
|
630 |
+
|
631 |
+
A neural network architecture that uses spiking neurons to process data. SNNs are particularly effective for neuromorphic computing applications, where they can operate in energy-efficient ways.
|
632 |
+
|
633 |
+
* **Usage**: Neuromorphic computing, edge AI
|
634 |
+
* **Strengths**: Excellent performance on edge AI applications, energy-efficient
|
635 |
+
* **Caveats**: Limited software support, require specialized hardware
|
636 |
+
```{dot}
|
637 |
+
digraph SNN {
|
638 |
+
// Graph properties
|
639 |
+
node [shape=record];
|
640 |
+
|
641 |
+
// Nodes definitions
|
642 |
+
input_neurons [label="Input Neurons" shape=ellipse];
|
643 |
+
synaptic_layers [label="Synaptic Layers\n(Weighted Connections)" shape=box];
|
644 |
+
spiking_neurons [label="Spiking Neurons" shape=box];
|
645 |
+
output_neurons [label="Output Neurons" shape=ellipse];
|
646 |
+
threshold_mechanism [label="Threshold Mechanism" shape=box];
|
647 |
+
spike_train [label="Spike Train Output" shape=ellipse];
|
648 |
+
|
649 |
+
// Edges definitions
|
650 |
+
input_neurons -> synaptic_layers;
|
651 |
+
synaptic_layers -> spiking_neurons;
|
652 |
+
spiking_neurons -> threshold_mechanism;
|
653 |
+
threshold_mechanism -> output_neurons;
|
654 |
+
output_neurons -> spike_train;
|
655 |
+
|
656 |
+
// Additional explanatory nodes
|
657 |
+
edge [style=dashed];
|
658 |
+
synaptic_layers -> threshold_mechanism [label="Dynamic Weights", style=dashed];
|
659 |
+
}
|
660 |
+
```
|
661 |
+
|
662 |
+
**Input Neurons**: These neurons receive the initial input signals, which could be any time-varying signal or a pattern encoded in the timing of spikes.
|
663 |
+
|
664 |
+
**Synaptic Layers**: Represents the connections between neurons. In SNNs, these connections are often dynamic, changing over time based on the activity of the network (Hebbian learning principles).
|
665 |
+
|
666 |
+
**Spiking Neurons**: Neurons that operate using spikes, which are brief and discrete events typically caused by reaching a certain threshold in the neuron’s membrane potential.
|
667 |
+
|
668 |
+
**Threshold Mechanism**: A critical component in SNNs that determines when a neuron should fire based on its membrane potential. This mechanism can adapt based on the history of spikes and neuronal activity.
|
669 |
+
|
670 |
+
**Output Neurons**: Neurons that produce the final output of the network. These may also operate using spikes, especially in SNNs designed for specific tasks like motor control or sensory processing.
|
671 |
+
|
672 |
+
**Spike Train Output**: The output from the network is often in the form of a spike train, representing the timing and sequence of spikes from the output neurons.
|
673 |
+
|
674 |
+
**Dynamic Weights**: Indicates that the synaptic weights are not static and can change based on the spike timing differences between pre- and post-synaptic neurons (STDP - Spike-Timing-Dependent Plasticity).
|
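The threshold mechanism is easiest to see in a single leaky integrate-and-fire neuron. The NumPy sketch below is illustrative only (time constants, input scale, and threshold are arbitrary): the membrane potential integrates weighted input current, leaks over time, and emits a spike and resets whenever it crosses the threshold.

```python
import numpy as np

rng = np.random.default_rng(1)
steps, dt = 200, 1.0
tau, threshold, v_reset = 20.0, 1.0, 0.0   # arbitrary illustrative constants

input_current = rng.random(steps) * 0.15    # stand-in for weighted input spikes
v = 0.0
spike_train = np.zeros(steps, dtype=int)

for t in range(steps):
    # Leaky integration of the membrane potential.
    v += dt / tau * (-v) + input_current[t]
    if v >= threshold:                      # threshold mechanism
        spike_train[t] = 1                  # emit a spike
        v = v_reset                         # reset after firing

print("number of spikes:", spike_train.sum())
```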
675 |
+
|
676 |
+
## Conditional Random Fields (CRFs)
|
677 |
+
|
678 |
+
A probabilistic graphical model for labeling and segmenting sequential data. CRFs are particularly effective for sequence labeling tasks, where they can model the dependencies between neighboring labels in a sequence.
|
679 |
+
|
680 |
+
* **Usage**: NLP, sequence labeling, information extraction
|
681 |
+
* **Strengths**: Excellent performance on sequential data, can model complex relationships
|
682 |
+
* **Caveats**: Computationally expensive, require large datasets
|
683 |
+
|
684 |
+
```{dot}
|
685 |
+
digraph CRF {
|
686 |
+
// Graph properties
|
687 |
+
node [shape=record];
|
688 |
+
|
689 |
+
// Nodes definitions
|
690 |
+
input_sequence [label="Input Sequence" shape=ellipse];
|
691 |
+
feature_extraction [label="Feature Extraction" shape=box];
|
692 |
+
crf_layer [label="CRF Layer" shape=box];
|
693 |
+
output_labels [label="Output Labels" shape=ellipse];
|
694 |
+
|
695 |
+
// Edges definitions
|
696 |
+
input_sequence -> feature_extraction;
|
697 |
+
feature_extraction -> crf_layer;
|
698 |
+
crf_layer -> output_labels;
|
699 |
+
|
700 |
+
// Additional nodes for clarity
|
701 |
+
state_transition [label="State Transition Features" shape=plaintext];
|
702 |
+
feature_extraction -> state_transition [style=dotted];
|
703 |
+
state_transition -> crf_layer [style=dotted];
|
704 |
+
}
|
705 |
+
```
|
706 |
+
|
707 |
+
**Input Sequence**: Represents the raw data input, such as sentences in text or other sequential data.
|
708 |
+
|
709 |
+
**Feature Extraction**: Processes the input data to extract features that are relevant for making predictions. This could include lexical features, part-of-speech tags, or contextual information in a natural language processing application.
|
710 |
+
|
711 |
+
**CRF Layer**: The core of the CRF model where the actual conditional random field is applied. This layer models the dependencies between labels in the sequence, considering both the input features and the labels of neighboring items in the sequence.
|
712 |
+
|
713 |
+
**Output Labels**: The final output of the CRF, which provides a label for each element in the input sequence. In the context of NLP, these might be tags for named entity recognition, part-of-speech tags, etc.
|
714 |
+
|
715 |
+
**State Transition Features**: This represents how CRFs utilize state transition features to model the relationships and dependencies between different labels in the sequence. These are not actual data flow but indicate the type of information that influences the CRF layer's decisions.
|
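The interplay of emission and transition scores can be shown with a tiny scoring function: the score of a label sequence is the sum of per-position (emission) scores plus the transition scores between consecutive labels. The NumPy sketch below is illustrative only (the tag set and random scores are hypothetical); a real linear-chain CRF normalizes this score over all possible label sequences (e.g. with the forward algorithm) and decodes with Viterbi.

```python
import numpy as np

labels = ["O", "B-PER", "I-PER"]             # hypothetical tag set
T = len(labels)
seq_len = 4

rng = np.random.default_rng(2)
emissions = rng.normal(size=(seq_len, T))    # per-token scores from a feature extractor
transitions = rng.normal(size=(T, T))        # score of moving from label i to label j

def sequence_score(tag_ids):
    """Unnormalized CRF score of one label sequence."""
    score = emissions[0, tag_ids[0]]
    for t in range(1, seq_len):
        score += transitions[tag_ids[t - 1], tag_ids[t]]  # transition feature
        score += emissions[t, tag_ids[t]]                 # emission feature
    return score

print(sequence_score([1, 2, 0, 0]))  # e.g. B-PER I-PER O O
```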
716 |
+
|
717 |
+
## Mixture of Experts (MoE)
|
718 |
+
|
719 |
+
A neural network architecture that consists of multiple expert networks (submodels), each specialized in different parts of the data or tasks. A gating network determines which expert(s) are most relevant for a given input. MoE is particularly effective for large-scale machine learning models, where it can dynamically route tasks to the most appropriate experts.
|
720 |
+
|
721 |
+
* **Usage**: Large-scale machine learning models, task-specific adaptations, dynamic routing of tasks
|
722 |
+
* **Strengths**: Highly scalable, capable of handling diverse tasks simultaneously, efficient use of resources by activating only relevant experts for each input.
|
723 |
+
* **Caveats**: Complex to implement and train, requires careful tuning to balance the load across experts and avoid overfitting in individual experts.
|
724 |
+
|
725 |
+
```{dot}
|
726 |
+
digraph MoE {
|
727 |
+
// Graph properties
|
728 |
+
node [shape=record];
|
729 |
+
|
730 |
+
// Nodes definitions
|
731 |
+
input_data [label="Input Data" shape=ellipse];
|
732 |
+
gating_network [label="Gating Network" shape=box];
|
733 |
+
expert1 [label="Expert 1" shape=box];
|
734 |
+
expert2 [label="Expert 2" shape=box];
|
735 |
+
expert3 [label="Expert 3" shape=box];
|
736 |
+
combined_output [label="Combined Output" shape=ellipse];
|
737 |
+
|
738 |
+
// Edges definitions
|
739 |
+
input_data -> gating_network;
|
740 |
+
gating_network -> expert1 [label="Weight"];
|
741 |
+
gating_network -> expert2 [label="Weight"];
|
742 |
+
gating_network -> expert3 [label="Weight"];
|
743 |
+
expert1 -> combined_output [label="Output 1"];
|
744 |
+
expert2 -> combined_output [label="Output 2"];
|
745 |
+
expert3 -> combined_output [label="Output 3"];
|
746 |
+
|
747 |
+
// Additional explanatory nodes
|
748 |
+
edge [style=dashed];
|
749 |
+
gating_network -> combined_output [label="Decision Weights", style=dotted];
|
750 |
+
}
|
751 |
+
```
|
752 |
+
|
753 |
+
**Input Data**: Represents the data being fed into the model. This could be anything from images, text, to structured data.
|
754 |
+
|
755 |
+
**Gating Network**: A crucial component that dynamically determines which expert model should handle the given input. It evaluates the input data and allocates weights to different experts based on their relevance to the current data point.
|
756 |
+
|
757 |
+
**Experts**: These are specialized models (expert1, expert2, expert3) that are trained on subsets of the data or specific types of tasks. Each expert processes the input independently.
|
758 |
+
|
759 |
+
**Combined Output**: The final output of the MoE model, which typically involves aggregating the outputs of the experts weighted by the gating network’s decisions.
|
760 |
+
|
761 |
+
**Weights**: These edges show how the gating network influences the contribution of each expert to the final decision. The weights are not fixed but are determined dynamically based on each input.
|
762 |
+
|
763 |
+
**Output 1, 2, 3**: These labels on the edges from experts to the combined output represent the contribution of each expert to the final model output. Each expert contributes its processed output, which is then combined based on the weights provided by the gating network.
|
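A dense (non-sparse) version of this gating-and-combination step fits in a few lines. The PyTorch sketch below is illustrative only (sizes are arbitrary): the gating network produces a softmax weight per expert, every expert processes the input, and the expert outputs are combined with those weights. Production MoE layers typically route each input to only the top-k experts for efficiency.

```python
import torch
import torch.nn as nn

class TinyMoE(nn.Module):
    def __init__(self, d_in=16, d_out=8, num_experts=3):
        super().__init__()
        self.experts = nn.ModuleList([nn.Linear(d_in, d_out) for _ in range(num_experts)])
        self.gate = nn.Linear(d_in, num_experts)

    def forward(self, x):
        weights = torch.softmax(self.gate(x), dim=-1)               # (batch, num_experts)
        outputs = torch.stack([e(x) for e in self.experts], dim=1)  # (batch, experts, d_out)
        return (weights.unsqueeze(-1) * outputs).sum(dim=1)         # weighted combination

moe = TinyMoE()
y = moe(torch.randn(4, 16))   # -> (4, 8)
```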
src/theory/layers.qmd
CHANGED
@@ -1,103 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
* Usage: Receive input data, propagate it to subsequent layers
|
5 |
* Description: The first layer in a neural network that receives input data
|
6 |
* Strengths: Essential for processing input data, easy to implement
|
7 |
* Weaknesses: Limited functionality, no learning occurs in this layer
|
8 |
|
9 |
-
##
|
10 |
|
11 |
* Usage: Feature extraction, classification, regression
|
12 |
* Description: A layer where every input is connected to every output, using a weighted sum
|
13 |
* Strengths: Excellent for feature extraction, easy to implement, fast computation
|
14 |
* Weaknesses: Can be prone to overfitting, computationally expensive for large inputs
|
15 |
|
16 |
-
##
|
17 |
|
18 |
* Usage: Image classification, object detection, image segmentation
|
19 |
* Description: A layer that applies filters to small regions of the input data, scanning the input data horizontally and vertically
|
20 |
* Strengths: Excellent for image processing, reduces spatial dimensions, retains spatial hierarchy
|
21 |
* Weaknesses: Computationally expensive, require large datasets
|
22 |
|
23 |
-
##
|
24 |
|
25 |
* Usage: Image classification, object detection, image segmentation
|
26 |
* Description: A layer that reduces spatial dimensions by taking the maximum or average value across a region
|
27 |
* Strengths: Reduces spatial dimensions, reduces number of parameters, retains important features
|
28 |
* Weaknesses: Loses some information, can be sensitive to hyperparameters
|
29 |
|
30 |
-
##
|
31 |
|
32 |
* Usage: Natural Language Processing (NLP), sequence prediction, time series forecasting
|
33 |
* Description: A layer that processes sequential data, using hidden state to capture temporal dependencies
|
34 |
* Strengths: Excellent for sequential data, can model long-term dependencies
|
35 |
* Weaknesses: Suffers from vanishing gradients, difficult to train, computationally expensive
|
36 |
|
37 |
-
##
|
38 |
|
39 |
* Usage: NLP, sequence prediction, time series forecasting
|
40 |
* Description: A type of RNN that uses memory cells to learn long-term dependencies
|
41 |
* Strengths: Excellent for sequential data, can model long-term dependencies, mitigates vanishing gradients
|
42 |
* Weaknesses: Computationally expensive, require large datasets
|
43 |
|
44 |
-
##
|
45 |
|
46 |
* Usage: NLP, sequence prediction, time series forecasting
|
47 |
* Description: A simpler alternative to LSTM, using gates to control the flow of information
|
48 |
* Strengths: Faster computation, simpler than LSTM, easier to train
|
49 |
* Weaknesses: May not perform as well as LSTM, limited capacity to model long-term dependencies
|
50 |
|
51 |
-
##
|
52 |
|
53 |
* Usage: Normalizing inputs, stabilizing training, improving performance
|
54 |
* Description: A layer that normalizes inputs, reducing internal covariate shift
|
55 |
* Strengths: Improves training stability, accelerates training, improves performance
|
56 |
* Weaknesses: Requires careful tuning of hyperparameters, can be computationally expensive
|
57 |
|
58 |
-
##
|
59 |
|
60 |
* Usage: Regularization, preventing overfitting
|
61 |
* Description: A layer that randomly drops out neurons during training, reducing overfitting
|
62 |
* Strengths: Effective regularization technique, reduces overfitting, improves generalization
|
63 |
* Weaknesses: Can slow down training, requires careful tuning of hyperparameters
|
64 |
|
65 |
-
##
|
66 |
|
67 |
* Usage: Reshaping data, preparing data for dense layers
|
68 |
* Description: A layer that flattens input data into a one-dimensional array
|
69 |
* Strengths: Essential for preparing data for dense layers, easy to implement
|
70 |
* Weaknesses: Limited functionality, no learning occurs in this layer
|
71 |
|
72 |
-
##
|
73 |
|
74 |
* Usage: NLP, word embeddings, language modeling
|
75 |
* Description: A layer that converts categorical data into dense vectors
|
76 |
* Strengths: Excellent for NLP tasks, reduces dimensionality, captures semantic relationships
|
77 |
* Weaknesses: Require large datasets, can be computationally expensive
|
78 |
|
79 |
-
##
|
80 |
|
81 |
* Usage: NLP, machine translation, question answering
|
82 |
* Description: A layer that computes weighted sums of input data, focusing on relevant regions
|
83 |
* Strengths: Excellent for sequential data, can model long-range dependencies, improves performance
|
84 |
* Weaknesses: Computationally expensive, require careful tuning of hyperparameters
|
85 |
|
86 |
-
##
|
87 |
|
88 |
* Usage: Image segmentation, object detection, image generation
|
89 |
* Description: A layer that increases spatial dimensions, using interpolation or learned upsampling filters
|
90 |
* Strengths: Excellent for image processing, improves spatial resolution, enables image generation
|
91 |
* Weaknesses: Computationally expensive, require careful tuning of hyperparameters
|
92 |
|
93 |
-
##
|
94 |
|
95 |
* Usage: Normalizing inputs, stabilizing training, improving performance
|
96 |
* Description: A layer that normalizes inputs, reducing internal covariate shift
|
97 |
* Strengths: Improves training stability, accelerates training, improves performance
|
98 |
* Weaknesses: Requires careful tuning of hyperparameters, can be computationally expensive
|
99 |
|
100 |
-
##
|
101 |
|
102 |
* Usage: Introducing non-linearity, enhancing model capacity
|
103 |
* Description: A function that introduces non-linearity into the model, enabling complex representations
|
|
|
1 |
+
---
|
2 |
+
title: Layer Types
|
3 |
+
format:
|
4 |
+
html:
|
5 |
+
mermaid:
|
6 |
+
theme: default
|
7 |
+
---
|
8 |
|
9 |
+
Neural networks are complex architectures made up of various types of layers, each performing distinct functions that contribute to the network's ability to learn from data. Understanding the different types of layers and their specific roles is essential for designing effective neural network models. This knowledge not only helps in building tailored architectures for different tasks but also aids in optimizing performance and efficiency.
|
10 |
+
|
11 |
+
Each layer in a neural network processes the input data in a unique way, and the choice of layers depends on the problem at hand. For instance, convolutional layers are primarily used in image processing tasks due to their ability to capture spatial hierarchies, while recurrent layers are favored in tasks involving sequential data like natural language processing or time series analysis due to their ability to maintain a memory of previous inputs.
|
12 |
+
|
13 |
+
The structure of a neural network can be seen as a stack of layers where each layer feeds into the next, transforming the input step-by-step into a more abstract and ultimately useful form. The output of each layer becomes the input for the next until a final output is produced. This modular approach allows for the construction of deep learning models that can handle a wide range of complex tasks, from speech recognition and image classification to generating coherent text and beyond.
|
14 |
+
|
15 |
+
In the sections that follow, we will explore various types of layers commonly used in neural networks, discussing their usage, descriptions, strengths, and weaknesses. This will include foundational layers like input and dense layers, as well as more specialized ones like convolutional, recurrent, and attention layers. We'll also look at layers designed for specific functions such as normalization, regularization, and activation, each vital for enhancing the network's learning capability and stability. This comprehensive overview will provide a clearer understanding of how each layer works and how they can be combined to create powerful neural network models.
|
16 |
+
|
17 |
+
## Input Layers
|
18 |
|
19 |
* Usage: Receive input data, propagate it to subsequent layers
|
20 |
* Description: The first layer in a neural network that receives input data
|
21 |
* Strengths: Essential for processing input data, easy to implement
|
22 |
* Weaknesses: Limited functionality, no learning occurs in this layer
|
23 |
|
24 |
+
## Dense Layers (Fully Connected Layers)
|
25 |
|
26 |
* Usage: Feature extraction, classification, regression
|
27 |
* Description: A layer where every input is connected to every output, using a weighted sum
|
28 |
* Strengths: Excellent for feature extraction, easy to implement, fast computation
|
29 |
* Weaknesses: Can be prone to overfitting, computationally expensive for large inputs
|
30 |
|
31 |
+
## Convolutional Layers (Conv Layers)
|
32 |
|
33 |
* Usage: Image classification, object detection, image segmentation
|
34 |
* Description: A layer that applies filters to small regions of the input data, scanning the input data horizontally and vertically
|
35 |
* Strengths: Excellent for image processing, reduces spatial dimensions, retains spatial hierarchy
|
36 |
* Weaknesses: Computationally expensive, require large datasets
|
37 |
|
38 |
+
## Pooling Layers (Downsampling Layers)
|
39 |
|
40 |
* Usage: Image classification, object detection, image segmentation
|
41 |
* Description: A layer that reduces spatial dimensions by taking the maximum or average value across a region
|
42 |
* Strengths: Reduces spatial dimensions, reduces number of parameters, retains important features
|
43 |
* Weaknesses: Loses some information, can be sensitive to hyperparameters
|
44 |
|
45 |
+
## Recurrent Layers (RNNs)
|
46 |
|
47 |
* Usage: Natural Language Processing (NLP), sequence prediction, time series forecasting
|
48 |
* Description: A layer that processes sequential data, using hidden state to capture temporal dependencies
|
49 |
* Strengths: Excellent for sequential data, can model long-term dependencies
|
50 |
* Weaknesses: Suffers from vanishing gradients, difficult to train, computationally expensive
|
51 |
|
52 |
+
## Long Short-Term Memory (LSTM) Layers
|
53 |
|
54 |
* Usage: NLP, sequence prediction, time series forecasting
|
55 |
* Description: A type of RNN that uses memory cells to learn long-term dependencies
|
56 |
* Strengths: Excellent for sequential data, can model long-term dependencies, mitigates vanishing gradients
|
57 |
* Weaknesses: Computationally expensive, require large datasets
|
58 |
|
59 |
+
## Gated Recurrent Unit (GRU) Layers
|
60 |
|
61 |
* Usage: NLP, sequence prediction, time series forecasting
|
62 |
* Description: A simpler alternative to LSTM, using gates to control the flow of information
|
63 |
* Strengths: Faster computation, simpler than LSTM, easier to train
|
64 |
* Weaknesses: May not perform as well as LSTM, limited capacity to model long-term dependencies
|
65 |
|
66 |
+
## Batch Normalization Layers
|
67 |
|
68 |
* Usage: Normalizing inputs, stabilizing training, improving performance
|
69 |
* Description: A layer that normalizes inputs, reducing internal covariate shift
|
70 |
* Strengths: Improves training stability, accelerates training, improves performance
|
71 |
* Weaknesses: Requires careful tuning of hyperparameters, can be computationally expensive
|
72 |
|
73 |
+
## Dropout Layers
|
74 |
|
75 |
* Usage: Regularization, preventing overfitting
|
76 |
* Description: A layer that randomly drops out neurons during training, reducing overfitting
|
77 |
* Strengths: Effective regularization technique, reduces overfitting, improves generalization
|
78 |
* Weaknesses: Can slow down training, requires careful tuning of hyperparameters
|
79 |
|
80 |
+
## Flatten Layers
|
81 |
|
82 |
* Usage: Reshaping data, preparing data for dense layers
|
83 |
* Description: A layer that flattens input data into a one-dimensional array
|
84 |
* Strengths: Essential for preparing data for dense layers, easy to implement
|
85 |
* Weaknesses: Limited functionality, no learning occurs in this layer
|
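To show how several of the layer types described so far fit together, here is a minimal, illustrative PyTorch sketch (assuming PyTorch is available; sizes are arbitrary): a flatten layer feeds a dense layer, followed by an activation, dropout for regularization, and a final dense output layer.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),            # flatten layer: e.g. 28x28 image -> 784-dim vector
    nn.Linear(784, 128),     # dense (fully connected) layer
    nn.ReLU(),               # activation introduces non-linearity
    nn.Dropout(p=0.2),       # dropout layer for regularization
    nn.Linear(128, 10),      # output layer, one logit per class
)

logits = model(torch.randn(32, 1, 28, 28))   # batch of 32 images -> (32, 10)
```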
86 |
|
87 |
+
## Embedding Layers
|
88 |
|
89 |
* Usage: NLP, word embeddings, language modeling
|
90 |
* Description: A layer that converts categorical data into dense vectors
|
91 |
* Strengths: Excellent for NLP tasks, reduces dimensionality, captures semantic relationships
|
92 |
* Weaknesses: Require large datasets, can be computationally expensive
|
93 |
|
94 |
+
## Attention Layers
|
95 |
|
96 |
* Usage: NLP, machine translation, question answering
|
97 |
* Description: A layer that computes weighted sums of input data, focusing on relevant regions
|
98 |
* Strengths: Excellent for sequential data, can model long-range dependencies, improves performance
|
99 |
* Weaknesses: Computationally expensive, require careful tuning of hyperparameters
|
100 |
|
101 |
+
## Upsampling Layers
|
102 |
|
103 |
* Usage: Image segmentation, object detection, image generation
|
104 |
* Description: A layer that increases spatial dimensions, using interpolation or learned upsampling filters
|
105 |
* Strengths: Excellent for image processing, improves spatial resolution, enables image generation
|
106 |
* Weaknesses: Computationally expensive, require careful tuning of hyperparameters
|
107 |
|
108 |
+
## Normalization Layers
|
109 |
|
110 |
* Usage: Normalizing inputs, stabilizing training, improving performance
|
111 |
* Description: A layer that normalizes inputs, reducing internal covariate shift
|
112 |
* Strengths: Improves training stability, accelerates training, improves performance
|
113 |
* Weaknesses: Requires careful tuning of hyperparameters, can be computationally expensive
|
114 |
|
115 |
+
## Activation Functions
|
116 |
|
117 |
* Usage: Introducing non-linearity, enhancing model capacity
|
118 |
* Description: A function that introduces non-linearity into the model, enabling complex representations
|
src/theory/metrics.qmd
CHANGED
@@ -1,150 +1,574 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
In machine learning, it's essential to evaluate the performance of a model to ensure it's accurate, reliable, and effective. There are various metrics to measure model performance, each with its strengths and limitations. Here's an overview of popular metrics, their pros and cons, and examples of tasks that apply to each.
|
4 |
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
MSE measures the average squared difference between predicted and actual values.
|
8 |
|
|
|
|
|
|
|
|
|
9 |
Pros:
|
10 |
|
11 |
* Easy to calculate
|
12 |
* Sensitive to outliers
|
13 |
-
|
|
|
14 |
Cons:
|
15 |
|
16 |
* Can be heavily influenced by extreme values
|
17 |
-
|
|
|
18 |
Example tasks:
|
19 |
|
20 |
* Regression tasks, such as predicting house prices or stock prices
|
21 |
* Time series forecasting
|
22 |
|
23 |
-
##
|
24 |
|
25 |
MAE measures the average absolute difference between predicted and actual values.
|
26 |
|
|
|
|
|
27 |
Pros:
|
28 |
|
29 |
* Robust to outliers
|
30 |
* Easy to interpret
|
31 |
-
|
|
|
32 |
Cons:
|
33 |
|
34 |
* Can be sensitive to skewness in the data
|
35 |
-
|
|
|
36 |
Example tasks:
|
37 |
|
38 |
* Regression tasks, such as predicting house prices or stock prices
|
39 |
* Time series forecasting
|
40 |
|
41 |
-
##
|
42 |
|
43 |
MAPE measures the average absolute percentage difference between predicted and actual values.
|
44 |
|
|
|
|
|
45 |
Pros:
|
46 |
|
47 |
* Easy to interpret
|
48 |
* Sensitive to relative errors
|
49 |
-
|
|
|
50 |
Cons:
|
51 |
|
52 |
* Can be sensitive to outliers
|
53 |
-
|
|
|
54 |
Example tasks:
|
55 |
|
56 |
* Regression tasks, such as predicting house prices or stock prices
|
57 |
* Time series forecasting
|
58 |
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
R² measures the proportion of variance in the dependent variable that's explained by the independent variables.
|
62 |
|
|
|
|
|
63 |
Pros:
|
64 |
|
65 |
* Easy to interpret
|
66 |
* Sensitive to the strength of the relationship
|
67 |
|
|
|
|
|
68 |
Cons:
|
69 |
|
70 |
* Can be sensitive to outliers
|
71 |
* Can be misleading for non-linear relationships
|
|
|
|
|
72 |
|
73 |
Example tasks:
|
74 |
|
75 |
* Regression tasks, such as predicting house prices or stock prices
|
76 |
* Feature selection
|
77 |
|
78 |
-
##
|
79 |
|
80 |
The Brier Score measures the average squared difference between predicted and actual probabilities.
|
81 |
|
|
|
|
|
82 |
Pros:
|
83 |
|
84 |
* Sensitive to the quality of the predictions
|
85 |
* Can handle multi-class classification tasks
|
86 |
-
|
|
|
87 |
Cons:
|
88 |
|
89 |
* Can be sensitive to the choice of threshold
|
|
|
|
|
90 |
|
91 |
Example tasks:
|
92 |
|
93 |
* Multi-class classification tasks, such as image classification
|
94 |
* Multi-label classification tasks
|
95 |
|
96 |
-
##
|
97 |
|
98 |
The F1 Score measures the harmonic mean of precision and recall.
|
99 |
|
|
|
|
|
100 |
Pros:
|
101 |
|
102 |
* Sensitive to the balance between precision and recall
|
103 |
* Can handle imbalanced datasets
|
104 |
|
|
|
|
|
105 |
Cons:
|
106 |
|
107 |
* Can be sensitive to the choice of threshold
|
108 |
|
|
|
|
|
|
|
109 |
Example tasks:
|
110 |
|
111 |
* Binary classification tasks, such as spam detection
|
112 |
* Multi-class classification tasks
|
113 |
|
114 |
-
##
|
115 |
|
116 |
MCC measures the correlation between predicted and actual labels.
|
117 |
|
|
|
|
|
118 |
Pros:
|
119 |
|
120 |
* Sensitive to the quality of the predictions
|
121 |
* Can handle imbalanced datasets
|
122 |
-
|
|
|
123 |
Cons:
|
124 |
|
125 |
* Can be sensitive to the choice of threshold
|
|
|
|
|
126 |
|
127 |
Example tasks:
|
128 |
|
129 |
* Binary classification tasks, such as spam detection
|
130 |
* Multi-class classification tasks
|
131 |
|
132 |
-
##
|
133 |
|
134 |
Log Loss measures the average log loss between predicted and actual probabilities.
|
135 |
|
|
|
|
|
136 |
Pros:
|
137 |
|
138 |
* Sensitive to the quality of the predictions
|
139 |
* Can handle multi-class classification tasks
|
140 |
-
|
|
|
141 |
Cons:
|
142 |
|
143 |
* Can be sensitive to the choice of threshold
|
|
|
|
|
144 |
|
145 |
Example tasks:
|
146 |
|
147 |
* Multi-class classification tasks, such as image classification
|
148 |
* Multi-label classification tasks
|
149 |
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Metrics
|
3 |
+
---
|
4 |
|
5 |
In machine learning, it's essential to evaluate the performance of a model to ensure it's accurate, reliable, and effective. There are various metrics to measure model performance, each with its strengths and limitations. Here's an overview of popular metrics, their pros and cons, and examples of tasks that apply to each.
|
6 |
|
7 |
+
For a quick [Overview](#overview) of metrics by use cases.
|
8 |
+
|
9 |
+
|
10 |
+
## Accuracy
|
11 |
+
|
12 |
+
Accuracy measures the ratio of correctly predicted observations to the total number of observations.
|
13 |
+
|
14 |
+
::: columns
|
15 |
+
::: {.column width="50%"}
|
16 |
+
Pros:
|
17 |
+
|
18 |
+
* Simplest and most intuitive metric.
|
19 |
+
:::
|
20 |
+
::: {.column width="50%"}
|
21 |
+
Cons:
|
22 |
+
|
23 |
+
* Can be misleading in the presence of imbalanced classes.
|
24 |
+
:::
|
25 |
+
:::
|
26 |
+
Example tasks:
|
27 |
+
|
28 |
+
* General classification tasks where classes are balanced.
|
29 |
+
* Entry-level benchmarking
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
## Mean Squared Error (MSE)
|
34 |
|
35 |
MSE measures the average squared difference between predicted and actual values.
|
36 |
|
37 |
+
|
38 |
+
::: columns
|
39 |
+
::: {.column width="50%"}
|
40 |
+
|
41 |
Pros:
|
42 |
|
43 |
* Easy to calculate
|
44 |
* Sensitive to outliers
|
45 |
+
:::
|
46 |
+
::: {.column width="50%"}
|
47 |
Cons:
|
48 |
|
49 |
* Can be heavily influenced by extreme values
|
50 |
+
:::
|
51 |
+
:::
|
52 |
Example tasks:
|
53 |
|
54 |
* Regression tasks, such as predicting house prices or stock prices
|
55 |
* Time series forecasting
|
56 |
|
57 |
+
## Mean Absolute Error (MAE)
|
58 |
|
59 |
MAE measures the average absolute difference between predicted and actual values.
|
60 |
|
61 |
+
::: columns
|
62 |
+
::: {.column width="50%"}
|
63 |
Pros:
|
64 |
|
65 |
* Robust to outliers
|
66 |
* Easy to interpret
|
67 |
+
:::
|
68 |
+
::: {.column width="50%"}
|
69 |
Cons:
|
70 |
|
71 |
* Can be sensitive to skewness in the data
|
72 |
+
:::
|
73 |
+
:::
|
74 |
Example tasks:
|
75 |
|
76 |
* Regression tasks, such as predicting house prices or stock prices
|
77 |
* Time series forecasting
|
78 |
|
79 |
+
## Mean Absolute Percentage Error (MAPE)
|
80 |
|
81 |
MAPE measures the average absolute percentage difference between predicted and actual values.
|
82 |
|
83 |
+
::: columns
|
84 |
+
::: {.column width="50%"}
|
85 |
Pros:
|
86 |
|
87 |
* Easy to interpret
|
88 |
* Sensitive to relative errors
|
89 |
+
:::
|
90 |
+
::: {.column width="50%"}
|
91 |
Cons:
|
92 |
|
93 |
* Can be sensitive to outliers
|
94 |
+
:::
|
95 |
+
:::
|
96 |
Example tasks:
|
97 |
|
98 |
* Regression tasks, such as predicting house prices or stock prices
|
99 |
* Time series forecasting
|
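The regression metrics above are simple enough to compute directly. A minimal NumPy sketch, assuming `y_true` and `y_pred` are arrays of actual and predicted values (the numbers below are made up for illustration):

```python
import numpy as np

y_true = np.array([3.0, 5.0, 2.5, 7.0])   # actual values (illustrative)
y_pred = np.array([2.5, 5.0, 4.0, 8.0])   # model predictions (illustrative)

mse = np.mean((y_true - y_pred) ** 2)                      # Mean Squared Error
mae = np.mean(np.abs(y_true - y_pred))                     # Mean Absolute Error
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100   # MAPE, in percent

print(f"MSE={mse:.3f}  MAE={mae:.3f}  MAPE={mape:.1f}%")
```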
100 |
|
101 |
+
|
102 |
+
## Binary Crossentropy
|
103 |
+
|
104 |
+
Binary Crossentropy measures the dissimilarity between the true binary labels and the predicted probabilities in binary classification tasks.
|
105 |
+
|
106 |
+
::: columns
|
107 |
+
::: {.column width="50%"}
|
108 |
+
Pros:
|
109 |
+
|
110 |
+
* Effective for measuring performance on binary classification problems.
|
111 |
+
* Provides a smooth, differentiable objective over the predicted probabilities.
|
112 |
+
|
113 |
+
:::
|
114 |
+
::: {.column width="50%"}
|
115 |
+
Cons:
|
116 |
+
|
117 |
+
* Not suitable for multi-class classification tasks.
|
118 |
+
* Sensitive to the balance of classes in the dataset.
|
119 |
+
:::
|
120 |
+
:::
|
121 |
+
Example tasks:
|
122 |
+
|
123 |
+
* Predicting customer churn (yes or no).
|
124 |
+
* Medical diagnostics with binary outcomes.
|
125 |
+
|
126 |
+
## Categorical Crossentropy
|
127 |
+
|
128 |
+
Categorical Crossentropy measures the dissimilarity between the true (one-hot) labels and the predicted probability distribution for multi-class classification tasks where each class is mutually exclusive.
|
129 |
+
|
130 |
+
::: columns
|
131 |
+
::: {.column width="50%"}
|
132 |
+
Pros:
|
133 |
+
|
134 |
+
* Well-suited for multi-class classification problems.
|
135 |
+
* Directly optimizes probability distribution across all classes.
|
136 |
+
|
137 |
+
:::
|
138 |
+
::: {.column width="50%"}
|
139 |
+
Cons:
|
140 |
+
|
141 |
+
* Requires one-hot encoding of labels.
|
142 |
+
* Can be computationally intensive with many class labels.
|
143 |
+
:::
|
144 |
+
:::
|
145 |
+
Example tasks:
|
146 |
+
|
147 |
+
* Image classification with multiple classes.
|
148 |
+
* Text classification into various categories.
|
149 |
+
|
150 |
+
## Sparse Categorical Crossentropy
|
151 |
+
|
152 |
+
Sparse Categorical Crossentropy is similar to Categorical Crossentropy but used for multi-class classification tasks where the classes are encoded as integers, not one-hot vectors.
|
153 |
+
|
154 |
+
::: columns
|
155 |
+
::: {.column width="50%"}
|
156 |
+
Pros:
|
157 |
+
|
158 |
+
* More memory efficient than Categorical Crossentropy when dealing with many classes.
|
159 |
+
* Eliminates the need for one-hot encoding of labels.
|
160 |
+
|
161 |
+
:::
|
162 |
+
::: {.column width="50%"}
|
163 |
+
Cons:
|
164 |
+
|
165 |
+
* Not suitable for problems where multiple categories may be applicable to a single observation.
|
166 |
+
* Requires careful handling of class indices to ensure they are mapped correctly.
|
167 |
+
:::
|
168 |
+
:::
|
169 |
+
Example tasks:
|
170 |
+
|
171 |
+
* Classifying news articles into a predefined set of topics.
|
172 |
+
* Recognizing the type of product from its description in a single-label system.
|
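The difference between the three crossentropy variants is mostly in how the labels are encoded. A minimal NumPy sketch (probabilities are made up; a small epsilon guards against log(0)):

```python
import numpy as np

eps = 1e-12  # avoid log(0)

# Binary crossentropy: labels in {0, 1}, predicted probability of class 1.
y_true = np.array([1, 0, 1])
p = np.array([0.9, 0.2, 0.6])
bce = -np.mean(y_true * np.log(p + eps) + (1 - y_true) * np.log(1 - p + eps))

# Categorical crossentropy: one-hot labels vs. a predicted distribution per sample.
y_onehot = np.array([[1, 0, 0], [0, 0, 1]])
probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
cce = -np.mean(np.sum(y_onehot * np.log(probs + eps), axis=1))

# Sparse categorical crossentropy: same loss, labels given as integer class indices.
y_int = np.array([0, 2])
scce = -np.mean(np.log(probs[np.arange(len(y_int)), y_int] + eps))

print(bce, cce, scce)   # cce and scce are identical for equivalent labels
```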
173 |
+
|
174 |
+
|
175 |
+
## R-Squared (R²)
|
176 |
|
177 |
R² measures the proportion of variance in the dependent variable that's explained by the independent variables.
|
178 |
|
179 |
+
::: columns
|
180 |
+
::: {.column width="50%"}
|
181 |
Pros:
|
182 |
|
183 |
* Easy to interpret
|
184 |
* Sensitive to the strength of the relationship
|
185 |
|
186 |
+
:::
|
187 |
+
::: {.column width="50%"}
|
188 |
Cons:
|
189 |
|
190 |
* Can be sensitive to outliers
|
191 |
* Can be misleading for non-linear relationships
|
192 |
+
:::
|
193 |
+
:::
|
194 |
|
195 |
Example tasks:
|
196 |
|
197 |
* Regression tasks, such as predicting house prices or stock prices
|
198 |
* Feature selection
|
199 |
|
200 |
+
## Brier Score
|
201 |
|
202 |
The Brier Score measures the average squared difference between predicted and actual probabilities.
|
203 |
|
204 |
+
::: columns
|
205 |
+
::: {.column width="50%"}
|
206 |
Pros:
|
207 |
|
208 |
* Sensitive to the quality of the predictions
|
209 |
* Can handle multi-class classification tasks
|
210 |
+
:::
|
211 |
+
::: {.column width="50%"}
|
212 |
Cons:
|
213 |
|
214 |
* Can be sensitive to the choice of threshold
|
215 |
+
:::
|
216 |
+
:::
|
217 |
|
218 |
Example tasks:
|
219 |
|
220 |
* Multi-class classification tasks, such as image classification
|
221 |
* Multi-label classification tasks
|
222 |
|
223 |
+
## F1 Score
|
224 |
|
225 |
The F1 Score measures the harmonic mean of precision and recall.
|
226 |
|
227 |
+
::: columns
|
228 |
+
::: {.column width="50%"}
|
229 |
Pros:
|
230 |
|
231 |
* Sensitive to the balance between precision and recall
|
232 |
* Can handle imbalanced datasets
|
233 |
|
234 |
+
:::
|
235 |
+
::: {.column width="50%"}
|
236 |
Cons:
|
237 |
|
238 |
* Can be sensitive to the choice of threshold
|
239 |
|
240 |
+
:::
|
241 |
+
:::
|
242 |
+
|
243 |
Example tasks:
|
244 |
|
245 |
* Binary classification tasks, such as spam detection
|
246 |
* Multi-class classification tasks
|
247 |
|
248 |
+
## Matthews Correlation Coefficient (MCC)
|
249 |
|
250 |
MCC measures the correlation between predicted and actual labels.
|
251 |
|
252 |
+
::: columns
|
253 |
+
::: {.column width="50%"}
|
254 |
Pros:
|
255 |
|
256 |
* Sensitive to the quality of the predictions
|
257 |
* Can handle imbalanced datasets
|
258 |
+
:::
|
259 |
+
::: {.column width="50%"}
|
260 |
Cons:
|
261 |
|
262 |
* Can be sensitive to the choice of threshold
|
263 |
+
:::
|
264 |
+
:::
|
265 |
|
266 |
Example tasks:
|
267 |
|
268 |
* Binary classification tasks, such as spam detection
|
269 |
* Multi-class classification tasks
|
270 |
|
271 |
+
## Log Loss
|
272 |
|
273 |
Log Loss measures the average negative log-likelihood of the true labels under the predicted probabilities.
|
274 |
|
275 |
+
::: columns
|
276 |
+
::: {.column width="50%"}
|
277 |
Pros:
|
278 |
|
279 |
* Sensitive to the quality of the predictions
|
280 |
* Can handle multi-class classification tasks
|
281 |
+
:::
|
282 |
+
::: {.column width="50%"}
|
283 |
Cons:
|
284 |
|
285 |
* Can be sensitive to the choice of threshold
|
286 |
+
:::
|
287 |
+
:::
|
288 |
|
289 |
Example tasks:
|
290 |
|
291 |
* Multi-class classification tasks, such as image classification
|
292 |
* Multi-label classification tasks
|
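In practice, most of the classification metrics discussed so far are one-liners in scikit-learn. A minimal sketch, assuming scikit-learn is installed and using made-up labels and probabilities:

```python
from sklearn.metrics import (accuracy_score, f1_score, log_loss,
                             matthews_corrcoef, roc_auc_score)

y_true = [1, 0, 1, 1, 0, 0, 1, 0]                   # ground-truth labels (illustrative)
y_prob = [0.9, 0.2, 0.7, 0.4, 0.1, 0.35, 0.8, 0.6]  # predicted probability of class 1
y_pred = [int(p >= 0.5) for p in y_prob]            # thresholded predictions

print("Accuracy:", accuracy_score(y_true, y_pred))
print("F1:      ", f1_score(y_true, y_pred))
print("MCC:     ", matthews_corrcoef(y_true, y_pred))
print("Log loss:", log_loss(y_true, y_prob))
print("AUC-ROC: ", roc_auc_score(y_true, y_prob))
```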
293 |
|
294 |
+
## Area Under the Receiver Operating Characteristic Curve (AUC-ROC)
|
295 |
+
|
296 |
+
AUC-ROC is the area under the ROC curve; it measures classification performance across all possible threshold settings.
|
297 |
+
|
298 |
+
::: columns
|
299 |
+
::: {.column width="50%"}
|
300 |
+
Pros:
|
301 |
+
|
302 |
+
* Measures the ability of the model to discriminate between classes.
|
303 |
+
* Useful for binary classification problems.
|
304 |
+
:::
|
305 |
+
::: {.column width="50%"}
|
306 |
+
Cons:
|
307 |
+
|
308 |
+
* Less effective when dealing with highly imbalanced datasets.
|
309 |
+
* Does not differentiate between types of errors.
|
310 |
+
:::
|
311 |
+
:::
|
312 |
+
|
313 |
+
Example tasks:
|
314 |
+
|
315 |
+
* Medical diagnosis classification.
|
316 |
+
* Spam detection in emails.
|
317 |
+
|
318 |
+
## Area Under the Precision-Recall Curve (AUC-PR)
|
319 |
+
|
320 |
+
The AUC-PR summarizes the precision-recall curve as a single number, which is the weighted average of precisions achieved at each threshold.
|
321 |
+
|
322 |
+
::: columns
|
323 |
+
::: {.column width="50%"}
|
324 |
+
Pros:
|
325 |
+
|
326 |
+
* Focuses on the positive class more than AUC-ROC, useful in imbalanced datasets.
|
327 |
+
:::
|
328 |
+
::: {.column width="50%"}
|
329 |
+
Cons:
|
330 |
+
|
331 |
+
* More complex to interpret and explain than AUC-ROC.
|
332 |
+
:::
|
333 |
+
:::
|
334 |
+
|
335 |
+
Example tasks:
|
336 |
+
|
337 |
+
* Information retrieval.
|
338 |
+
* Ranking tasks where positive class prevalence is low.
|
339 |
+
|
340 |
+
|
341 |
+
## Precision and Recall
|
342 |
+
|
343 |
+
Precision measures the ratio of correctly predicted positive observations to the total predicted positives. Recall measures the ratio of correctly predicted positive observations to all observations in the actual positive class.
|
344 |
+
|
345 |
+
::: columns
|
346 |
+
::: {.column width="50%"}
|
347 |
+
Pros:
|
348 |
+
|
349 |
+
* Provide a more detailed understanding of model performance than accuracy, especially in imbalanced datasets.
|
350 |
+
|
351 |
+
:::
|
352 |
+
::: {.column width="50%"}
|
353 |
+
Cons:
|
354 |
+
|
355 |
+
* High precision can sometimes be achieved at the expense of recall, and vice versa (precision-recall trade-off).
|
356 |
+
:::
|
357 |
+
:::
|
358 |
+
|
359 |
+
Example tasks:
|
360 |
+
|
361 |
+
* Document classification.
|
362 |
+
* Customer churn prediction.
|
363 |
+
|
364 |
+
|
365 |
+
## Cohen’s Kappa
|
366 |
+
|
367 |
+
Cohen’s Kappa measures the agreement between two raters who each classify N items into C mutually exclusive categories.
|
368 |
+
|
369 |
+
::: columns
|
370 |
+
::: {.column width="50%"}
|
371 |
+
Pros:
|
372 |
+
|
373 |
+
* Accounts for the possibility of the agreement occurring by chance.
|
374 |
+
* More robust than simple accuracy.
|
375 |
+
:::
|
376 |
+
::: {.column width="50%"}
|
377 |
+
Cons:
|
378 |
+
|
379 |
+
* Can be difficult to interpret, especially with multiple classes.
|
380 |
+
:::
|
381 |
+
:::
|
382 |
+
|
383 |
+
Example tasks:
|
384 |
+
|
385 |
+
* Multi-rater reliability test.
|
386 |
+
* Medical image classification where multiple doctors provide diagnoses.
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
---
|
391 |
+
|
392 |
+
## BLEU (Bilingual Evaluation Understudy)
|
393 |
+
|
394 |
+
BLEU is a metric used to evaluate the quality of text which has been machine-translated from one natural language to another.
|
395 |
+
|
396 |
+
::: columns
|
397 |
+
::: {.column width="50%"}
|
398 |
+
Pros:
|
399 |
+
|
400 |
+
* Provides a quantitative measure for the quality of machine translation.
|
401 |
+
* Widely used and easy to compute.
|
402 |
+
|
403 |
+
:::
|
404 |
+
::: {.column width="50%"}
|
405 |
+
Cons:
|
406 |
+
|
407 |
+
* May not capture the fluency and grammaticality of the translation.
|
408 |
+
* Focuses more on precision than recall.
|
409 |
+
|
410 |
+
:::
|
411 |
+
:::
|
412 |
+
Example tasks:
|
413 |
+
|
414 |
+
* Machine translation evaluation.
|
415 |
+
* Automated text generation assessment.
|
416 |
+
|
417 |
+
## ROUGE (Recall-Oriented Understudy for Gisting Evaluation)
|
418 |
+
|
419 |
+
ROUGE is used to evaluate automatic summarization of texts as well as machine translation.
|
420 |
+
|
421 |
+
::: columns
|
422 |
+
::: {.column width="50%"}
|
423 |
+
Pros:
|
424 |
+
|
425 |
+
* Measures both the quality and quantity of match between reference and generated summaries.
|
426 |
+
* Supports evaluation of multiple types of summaries (extractive, abstractive).
|
427 |
+
|
428 |
+
:::
|
429 |
+
::: {.column width="50%"}
|
430 |
+
Cons:
|
431 |
+
|
432 |
+
* Can be biased towards extractive summarization techniques.
|
433 |
+
* Depends heavily on the quality of reference summaries.
|
434 |
+
|
435 |
+
:::
|
436 |
+
:::
|
437 |
+
Example tasks:
|
438 |
+
|
439 |
+
* Text summarization.
|
440 |
+
* Machine translation evaluation.
|
441 |
+
|
442 |
+
## METEOR (Metric for Evaluation of Translation with Explicit ORdering)

METEOR is a metric for evaluating machine translation that goes beyond simple vocabulary overlap.

::: columns
::: {.column width="50%"}
Pros:

* Incorporates synonyms, stemming, and paraphrase matching.
* Correlates more closely with human judgment than BLEU.

:::
::: {.column width="50%"}
Cons:

* More complex to compute than BLEU.
* Can require extensive linguistic resources.

:::
:::

Example tasks:

* Machine translation.
* Natural language generation.

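A minimal sketch using NLTK's implementation (recent NLTK versions expect pre-tokenized input, and the WordNet corpus must be available for synonym matching):

```python
import nltk
from nltk.translate.meteor_score import meteor_score

nltk.download("wordnet", quiet=True)  # METEOR uses WordNet for synonym matching

reference = ["the", "cat", "sat", "on", "the", "mat"]
hypothesis = ["the", "cat", "was", "sitting", "on", "the", "mat"]

print(f"METEOR: {meteor_score([reference], hypothesis):.3f}")
```
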
## Matthews Correlation Coefficient (MCC)

MCC measures the quality of binary classifications, taking all four cells of the confusion matrix into account.

::: columns
::: {.column width="50%"}
Pros:

* Provides a balanced measure even if classes are of very different sizes.
* Considers true and false positives and negatives.

:::
::: {.column width="50%"}
Cons:

* Primarily designed for binary classification; the multi-class generalization is less commonly used.
* Can be difficult to interpret in non-binary classification contexts.

:::
:::

Example tasks:

* Binary classification tasks such as spam detection.
* Medical imaging classification.

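A minimal sketch with scikit-learn (labels invented to mimic an imbalanced problem):

```python
from sklearn.metrics import matthews_corrcoef

# Hypothetical imbalanced binary problem
y_true = [1, 0, 0, 0, 0, 0, 1, 0, 0, 1]
y_pred = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]

print(f"MCC: {matthews_corrcoef(y_true, y_pred):.3f}")  # ranges from -1 to +1, 0 = chance level
```
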
## Precision-k

Precision-k measures the proportion of relevant items found in the top-k recommendations of a system.

::: columns
::: {.column width="50%"}
Pros:

* Useful for evaluating ranking systems where the order of items is important.
* Easy to understand and implement.

:::
::: {.column width="50%"}
Cons:

* Does not consider the actual ranking of items beyond the cutoff of k.
* Sensitive to the choice of k.

:::
:::

Example tasks:

* Recommender systems.
* Search engine result ranking.

## Recall-k

Recall-k measures the proportion of relevant items that appear in the top-k recommendations out of all relevant items.

::: columns
::: {.column width="50%"}
Pros:

* Indicates the ability of a system to retrieve highly relevant items.
* Useful for systems where retrieving most relevant items is critical.

:::
::: {.column width="50%"}
Cons:

* Like precision-k, it is sensitive to the choice of k.
* Does not account for the relevance of items outside the top k.

:::
:::

Example tasks:

* Information retrieval.
* Content-based filtering in recommender systems.

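Both ranking metrics are straightforward to implement by hand. A minimal sketch covering Precision-k and Recall-k (the `precision_at_k` and `recall_at_k` helpers and the example data are hypothetical):

```python
def precision_at_k(recommended, relevant, k):
    """Fraction of the top-k recommendations that are relevant."""
    top_k = recommended[:k]
    return len(set(top_k) & set(relevant)) / k

def recall_at_k(recommended, relevant, k):
    """Fraction of all relevant items that appear in the top-k recommendations."""
    top_k = recommended[:k]
    return len(set(top_k) & set(relevant)) / len(relevant)

recommended = ["a", "b", "c", "d", "e"]  # ranked system output
relevant = {"a", "c", "f"}               # ground-truth relevant items

print("Precision@3:", precision_at_k(recommended, relevant, 3))  # 2/3
print("Recall@3:   ", recall_at_k(recommended, relevant, 3))     # 2/3
```
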
## Overview

Below is a table outlining typical example cases in machine learning and the corresponding metrics that could be effectively used for evaluating model performance in each case. This table helps in quickly identifying which metrics are most applicable for a given type of task.

| Use Case | Relevant Metrics |
|------------------------------------------|---------------------------------------------------------------------------------------|
| **Predicting House Prices** | MSE, MAE, R² |
| **Stock Price Forecasting** | MSE, MAE, R² |
| **Medical Diagnosis Classification** | Accuracy, F1 Score, AUC-ROC, Cohen’s Kappa |
| **Spam Detection** | Precision, Recall, F1 Score, MCC, AUC-ROC |
| **Image Classification** | Accuracy, F1 Score, Brier Score, Log Loss, AUC-ROC |
| **Multi-label Classification** | F1 Score, Brier Score, Log Loss, AUC-PR |
| **Customer Churn Prediction** | Accuracy, Precision, Recall, F1 Score |
| **Document Classification** | Accuracy, Precision, Recall, F1 Score |
| **Ranking Tasks** | AUC-PR, Log Loss |
| **Multi-rater Reliability Tests** | Cohen’s Kappa |
| **Multi-class Text Classification** | Categorical Crossentropy, Sparse Categorical Crossentropy, Accuracy, F1 Score |
| **Binary Classification of Customer Data** | Binary Crossentropy, Accuracy, Precision, Recall, F1 Score |
| **Sentiment Analysis** | Accuracy, Precision, Recall, F1 Score, MAE, Log Loss |
| **Time Series Forecasting** | MSE, MAE, R², MAPE |
| **Anomaly Detection** | F1 Score, Precision, Recall, MCC |
| **Search Engine Ranking** | Precision-k, Recall-k, AUC-ROC, AUC-PR |
| **Natural Language Generation** | BLEU, ROUGE, METEOR |
| **Recommendation Systems** | Precision, Recall, F1 Score, AUC-ROC |
| **Real-time Bidding in Advertising** | Log Loss, AUC-ROC |
| **User Engagement Prediction** | MSE, MAE, Log Loss, Brier Score |
| **Protein Structure Prediction** | Accuracy, MCC, F1 Score |
| **Drug Response Modeling** | MSE, MAE, R², Log Loss |
| **Credit Scoring** | AUC-ROC, Precision, Recall, F1 Score, Log Loss |
| **Fraud Detection** | Accuracy, Precision, Recall, F1 Score, AUC-ROC |

Each task is paired with metrics that are most appropriate based on the nature of the data and the specific requirements of the task.
This approach helps ensure that the evaluation of a model is not only accurate but also relevant to the goals of the analysis.
src/theory/optimizers.qmd
ADDED
@@ -0,0 +1,89 @@
---
title: Optimizers
---

Optimizers play a critical role in training neural networks by updating the network's weights based on the loss gradient. The choice of an optimizer can significantly impact the speed and quality of training, making it a fundamental component of deep learning. This page explores various types of optimizers, their mechanisms, and their applications, providing insights into how they work and why certain optimizers are preferred for specific tasks.

[Slideshow](optimizers_slideshow.qmd)

## The Role of Optimizers

The primary function of an optimizer is to minimize (or maximize) a loss function or objective function that measures how well the model performs on a given task. This is achieved by iteratively adjusting the weights of the network. Optimizers not only help in converging to a solution more quickly but also affect the stability and quality of the model. They navigate the complex, high-dimensional landscape formed by a model's weights and aim to find a combination that results in the best possible predictions.

## Types of Optimizers

### Gradient Descent

The simplest optimizer: it updates the weights by moving them in the direction of the negative gradient of the objective function with respect to those weights, typically computed over the full training set.

* **Usage**: Basic learning tasks, small datasets
* **Strengths**: Simple, easy to understand and implement
* **Caveats**: Slow convergence, sensitive to the choice of learning rate, can get stuck in local minima

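A minimal NumPy sketch of plain gradient descent on a one-dimensional least-squares problem (the data and learning rate are made up for illustration):

```python
import numpy as np

# Fit w in y ≈ w * x by minimizing the mean squared error
x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([2.0, 4.0, 6.0, 8.0])  # true relationship: y = 2x

w = 0.0     # initial weight
lr = 0.05   # learning rate
for step in range(100):
    grad = np.mean(2 * (w * x - y) * x)  # derivative of the mean squared error w.r.t. w
    w -= lr * grad                       # step against the gradient

print(f"Estimated w: {w:.3f}")  # converges towards 2.0
```
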
### Stochastic Gradient Descent (SGD)

An extension of the gradient descent algorithm that updates the model's weights using only a single sample or a mini-batch of samples, which makes the training process much faster.

* **Usage**: General machine learning and deep learning tasks
* **Strengths**: Faster convergence than standard gradient descent, less memory intensive
* **Caveats**: Variability in the training updates can lead to unstable convergence

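A minimal PyTorch sketch of mini-batch SGD on a synthetic regression problem (model, data, and hyperparameters are illustrative only):

```python
import torch
from torch import nn

# Synthetic regression data: y = 3x + noise
X = torch.randn(256, 1)
y = 3 * X + 0.1 * torch.randn(256, 1)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for epoch in range(20):
    for i in range(0, len(X), 32):       # mini-batches of 32 samples
        xb, yb = X[i:i + 32], y[i:i + 32]
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)
        loss.backward()                  # backpropagate the batch loss
        optimizer.step()                 # update weights with the batch gradient

print(f"Learned weight: {model.weight.item():.3f}")  # close to 3.0
```
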
### Momentum

SGD with momentum considers the past gradients to smooth out the update. It helps accelerate SGD in the relevant direction and dampens oscillations.

* **Usage**: Deep networks, training with high variability or sparse gradients
* **Strengths**: Faster convergence than SGD, reduces oscillations in updates
* **Caveats**: Additional hyperparameter to tune (momentum coefficient)

### Nesterov Accelerated Gradient (NAG)

A variant of the momentum method that helps to speed up training. NAG first makes a big jump in the direction of the previous accumulated gradient, then measures the gradient where it ends up and makes a correction.

* **Usage**: Convolutional neural networks, large-scale neural networks
* **Strengths**: Often converges faster than momentum
* **Caveats**: Can overshoot in settings with noisy data

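In PyTorch, both classical momentum and the Nesterov variant are options of the same `SGD` optimizer (a sketch with placeholder parameters and illustrative values):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(10))]  # placeholder parameters

# Classical momentum
sgd_momentum = torch.optim.SGD(params, lr=0.01, momentum=0.9)

# Nesterov accelerated gradient: same optimizer with nesterov=True
sgd_nesterov = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True)
```
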
### Adagrad

An algorithm that adapts the learning rate to the parameters, performing larger updates for infrequent parameters and smaller updates for frequent parameters. Useful for sparse data.

* **Usage**: Sparse datasets, NLP and image recognition
* **Strengths**: Removes the need to manually tune the learning rate
* **Caveats**: The accumulated squared gradients in the denominator can cause the learning rate to shrink and become extremely small

### RMSprop

Addresses the rapidly diminishing learning rates of Adagrad by using a moving average of squared gradients to normalize the gradient. This ensures that the learning rate does not decrease too quickly.

* **Usage**: Non-stationary objectives, training RNNs
* **Strengths**: Balances the step size decrease, making it more robust
* **Caveats**: Still requires setting a learning rate

### Adam (Adaptive Moment Estimation)

Combines the advantages of Adagrad and RMSprop by maintaining exponential moving averages of both the gradients and the squared gradients. It can handle non-stationary objectives and problems with very noisy and/or sparse gradients.

* **Usage**: Broad range of applications from general machine learning to deep learning
* **Strengths**: Computationally efficient, little memory requirement, well suited for problems with lots of data and/or parameters
* **Caveats**: Can sometimes lead to suboptimal solutions for some problems

### AdamW

AdamW is a variant of the Adam optimizer that incorporates weight decay directly into the optimization process. By decoupling the weight decay from the optimization steps, AdamW tends to outperform standard Adam, especially in settings where regularization and preventing overfitting are crucial.

* **Usage**: Training deep neural networks across a wide range of tasks including classification and regression where regularization is key.
* **Strengths**: Addresses some of the issues found in Adam related to poor generalization performance. It provides a more effective way to use L2 regularization, avoiding common pitfalls of Adam related to the scale of the updates.
* **Caveats**: Like Adam, it requires tuning of hyperparameters such as the learning rate and weight decay coefficients. It may still suffer from some of the convergence issues inherent to adaptive gradient methods, but to a lesser extent.

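For reference, the adaptive optimizers above are available out of the box in PyTorch; a sketch of their construction (placeholder parameters, illustrative hyperparameters):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(10))]  # placeholder parameters

adagrad = torch.optim.Adagrad(params, lr=0.01)
rmsprop = torch.optim.RMSprop(params, lr=0.001, alpha=0.99)       # alpha = moving-average factor
adam    = torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999))
adamw   = torch.optim.AdamW(params, lr=0.001, weight_decay=0.01)  # decoupled weight decay
```
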
### AdaMax, Nadam

Variations of Adam with modifications for better convergence in specific scenarios.

* **Usage**: Specific optimizations where Adam shows suboptimal behavior
* **Strengths**: Provides alternative ways to scale the learning rates
* **Caveats**: Can be more sensitive to hyperparameter settings

## Conclusion

Choosing the right optimizer is crucial as it directly influences the efficiency and outcome of training neural networks. While some optimizers are better suited for large datasets and models, others might be designed to handle specific types of data or learning tasks more effectively. Understanding the strengths and limitations of each optimizer helps in selecting the most appropriate one for a given problem, leading to better performance and more robust models.

src/theory/optimizers_slideshow.qmd
ADDED
@@ -0,0 +1,100 @@
---
title: "Optimizers in Neural Networks"
author: "Sébastien De Greef"
format:
  revealjs:
    theme: solarized
    navigation-mode: grid
    controls-layout: bottom-right
    controls-tutorial: true
    notebook-links: false
crossref:
  lof-title: "List of Figures"
number-sections: false
---

## Introduction to Optimizers

Optimizers drive the training of neural networks by updating the network's weights based on the loss gradient. They impact the training speed, the quality of convergence, and the model's final performance.

---

## Role of Optimizers

- **Function**: Minimize the loss function
- **Mechanism**: Iteratively adjust the weights
- **Impact**: Affect efficiency, accuracy, and model feasibility

---

## Gradient Descent

- **Usage**: Basic learning tasks, small datasets
- **Strengths**: Simple, easy to understand
- **Caveats**: Slow convergence, sensitive to learning rate settings

---

## Stochastic Gradient Descent (SGD)

- **Usage**: General learning tasks
- **Strengths**: Faster than batch gradient descent
- **Caveats**: Higher variance in updates

---

## Momentum

- **Usage**: Training deep networks
- **Strengths**: Accelerates SGD, dampens oscillations
- **Caveats**: Additional hyperparameter (momentum)

---

## Nesterov Accelerated Gradient (NAG)

- **Usage**: Large-scale neural networks
- **Strengths**: Faster convergence than Momentum
- **Caveats**: Can overshoot in noisy settings

---

## Adagrad

- **Usage**: Sparse data problems like NLP and image recognition
- **Strengths**: Adapts the learning rate to the parameters
- **Caveats**: Shrinking learning rate over time

---

## RMSprop

- **Usage**: Non-stationary objectives, training RNNs
- **Strengths**: Balances decreasing learning rates
- **Caveats**: Still requires learning rate setting

---

## Adam (Adaptive Moment Estimation)

- **Usage**: Broad range of deep learning tasks
- **Strengths**: Efficient, handles noisy/sparse gradients well
- **Caveats**: Complex hyperparameter tuning

---

## AdamW

- **Usage**: Regularization-heavy tasks
- **Strengths**: Better generalization than Adam
- **Caveats**: Requires careful tuning of decay terms

---

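## Optimizers in Code

A minimal PyTorch sketch (placeholder model, illustrative hyperparameters):

```python
import torch

model = torch.nn.Linear(10, 1)  # placeholder model

sgd   = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
adamw = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
```

---
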
## Conclusion

Choosing the right optimizer is crucial for training efficiency and model performance.

Each optimizer has its strengths and is suited for specific types of tasks.

src/theory/training.qmd
ADDED
@@ -0,0 +1,191 @@
---
title: Training
---

Training is a fundamental aspect of developing artificial intelligence (AI) systems. It involves teaching AI models to make predictions or decisions based on data. This process is crucial for AI models to learn from experience and improve over time. This article covers the various facets of training in AI, including its principles, methods, types of training data, challenges, and best practices.

[The spelled-out intro to neural networks and backpropagation from Andrej Karpathy](https://www.youtube.com/watch?v=VMj-3S1tku0&t=1734s)

## Principles of AI Training

**1. Learning from Data**: At the core of AI training is the principle that models learn from data. The quality, quantity, and relevance of the data directly influence how well an AI model can perform its tasks.

**2. Generalization**: The ultimate goal of training an AI model is for it to generalize from its training data to new, unseen situations. Generalization ensures that the model performs well not just on the data it was trained on but also on new data.

**3. Overfitting and Underfitting**: Two common challenges in training AI models are overfitting and underfitting. Overfitting occurs when a model learns the training data too well, including the noise and errors, and performs poorly on new data. Underfitting happens when a model is too simple to learn the underlying pattern of the data.

## Methods of AI Training

**1. Supervised Learning**: This is the most common training method, where the model learns from a labeled dataset. It tries to learn a function that, given a set of inputs, produces the correct output.

**2. Unsupervised Learning**: In unsupervised learning, the model learns from data without labels. The goal is to identify patterns and relationships in the data.

**3. Semi-Supervised Learning**: This method combines a small amount of labeled data with a large amount of unlabeled data during training. It is useful when labeling data is expensive or time-consuming.

**4. Reinforcement Learning**: Here, an agent learns to make decisions by performing actions in an environment to maximize some notion of cumulative reward.

**5. Transfer Learning**: In transfer learning, a model developed for a specific task is reused as the starting point for a model on a second task. It is an effective strategy when you have a small amount of data for the second task.

## Types of Training Data

**1. Labeled Data**: Data that has been tagged with one or more labels identifying certain properties or classifications, used in supervised learning.

**2. Unlabeled Data**: Data that does not contain labels, used mainly in unsupervised learning setups.

**3. Synthetic Data**: Artificially created data that mimics real-world data, useful when training data is insufficient or hard to collect.

**4. Augmented Data**: Real data that has been modified or expanded through techniques such as rotation, scaling, or cropping to improve the robustness of the model.

## Challenges in AI Training

**1. Data Quality**: Poor quality data can lead the AI to make incorrect predictions. Ensuring data is clean, representative, and well-prepared is crucial.

**2. Scalability**: As models and data grow, it becomes challenging to scale training processes efficiently.

**3. Bias**: AI systems can inadvertently learn and perpetuate biases present in the training data.

**4. Computational Resources**: Training state-of-the-art AI models often requires significant computational resources, which can be expensive and energy-intensive.

## Best Practices in AI Training

**1. Data Preprocessing**: Clean and preprocess data to improve quality and training efficiency.

**2. Model Selection**: Choose the right model based on the complexity of the task and the nature of the data.

**3. Regularization Techniques**: Use techniques such as dropout, L1 and L2 regularization to prevent overfitting.

**4. Cross-validation**: Use cross-validation techniques to ensure the model's performance is consistent across different subsets of the data.

**5. Continuous Monitoring and Updating**: Regularly update the model with new data to adapt to changes in the underlying data distribution.

**6. Ethical Considerations**: Address ethical considerations, ensuring the AI system does no harm and works fairly.

Training is a dynamic and ongoing process in AI development. Understanding its intricacies helps in designing models that are not only accurate but also robust and fair, capable of performing well across a wide range of scenarios. This understanding forms the backbone of successful AI implementations, paving the way for innovative applications that can truly leverage the power of artificial intelligence.

## Diagnosing the Training Process in AI

Diagnosing the training process is crucial for developing effective AI models. It involves monitoring the model during training to identify and resolve issues that can negatively impact its performance. Here, we cover how to recognize and troubleshoot common problems like vanishing and exploding gradients, as well as unusual observations in training metrics such as loss and validation loss.

### Common Training Issues

**1. Vanishing Gradients**: This occurs when gradients, used in the training process to update weights, become very small, effectively preventing weights from changing their values. As a result, the training process stalls.

**Causes**:

- Deep networks with many layers using non-linear activation functions that squash input values into a small range, like the sigmoid or tanh functions.

- Improperly initialized weights.

**Diagnosis**:

- Monitor the gradients during training. If the gradients are consistently near zero, vanishing gradients may be occurring.

- Use histogram summaries in TensorBoard or similar tools to visualize layer outputs and weights during training.

**Solutions**:

- Use ReLU or variants of ReLU, which are less likely to cause vanishing gradients because they do not squash large input values.

- Implement better weight initialization strategies, like He or Glorot initialization.

- Use Batch Normalization to maintain healthy gradients throughout the network.

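A minimal PyTorch sketch of the diagnosis step, inspecting per-parameter gradient norms after a backward pass (the toy model and data are made up; consistently tiny norms in the early layers point to vanishing gradients):

```python
import torch
from torch import nn

# Toy setup just to have gradients to inspect
model = nn.Sequential(nn.Linear(10, 10), nn.Sigmoid(), nn.Linear(10, 1))
x, y = torch.randn(32, 10), torch.randn(32, 1)

loss = nn.MSELoss()(model(x), y)
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"{name:15s} grad norm = {param.grad.norm().item():.6f}")
```
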
**2. Exploding Gradients**: This problem occurs when gradients grow exponentially through the layers during backpropagation, leading to very large updates to weights and, consequently, an unstable network.

**Causes**:

- In deep networks, the repeated multiplication of gradients through many layers can make them grow exponentially.

- High learning rates.

**Diagnosis**:

- Monitor the gradients. If the gradient values are increasing dramatically over epochs, it’s likely an issue.

- Watch for NaN values in gradients or weights.

**Solutions**:

- Apply gradient clipping to limit the maximum value of gradients during backpropagation.

- Adjust the learning rate.

- Use weight regularization techniques, like L2 regularization, to penalize large weights.

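A minimal PyTorch sketch of gradient clipping inside a training step (placeholder model and data; the clipping threshold is illustrative):

```python
import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer.zero_grad()
loss = nn.MSELoss()(model(x), y)
loss.backward()

# Rescale gradients so their global norm does not exceed 1.0 before the update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```
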
### Unusual Metrics Observations

**1. High or Non-Decreasing Training Loss**: If the loss does not decrease, or decreases very slowly, the model is not learning effectively.

**Causes**:

- Inappropriate model architecture.

- Inadequate learning rate (too high or too low).

- Poor quality or insufficient training data.

**Diagnosis**:

- Plot the loss over training epochs. A flat line or a line that does not trend downwards indicates a problem.

- Evaluate learning rate and data quality.

**Solutions**:

- Adjust the learning rate.

- Check and preprocess the training data correctly.

- Consider changing the model architecture.

**2. High Variance Between Training and Validation Loss (Overfitting)**: If the training loss decreases but the validation loss does not decrease or increases, the model may be overfitting.

**Causes**:

- Model is too complex, with too many parameters.

- Insufficient or non-representative training data.

**Diagnosis**:

- Monitor both training and validation loss. A diverging pattern suggests overfitting.

**Solutions**:

- Simplify the model by reducing the number of layers or parameters.

- Use dropout or regularization techniques.

- Increase training data, or use data augmentation.

**3. High Bias Between Training and Validation Loss (Underfitting)**: If both training and validation losses are high, or the model performs poorly even on training data, the model may be underfitting.

**Causes**:

- Overly simple model unable to capture underlying patterns.

- Inadequate training epochs.

**Diagnosis**:

- If both losses are high, evaluate the complexity of the model.

**Solutions**:

- Increase model complexity by adding more layers or parameters.

- Train for more epochs.

- Experiment with different model architectures.

### Advanced Diagnostics Tools

**1. Learning Rate Schedulers**: Implement learning rate schedulers to adjust the learning rate during training, which can help in stabilizing the training process.

**2. Early Stopping**: Use early stopping to terminate training when validation metrics stop improving, preventing overfitting and saving computational resources.

**3. Regularization Techniques**: Techniques like L1 and L2 regularization penalize large weights, helping control overfitting.

**4. Hyperparameter Tuning**: Use grid search or random search to optimize hyperparameters like the number of layers, number of neurons, learning rate, etc.

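A minimal PyTorch sketch combining a plateau-based learning rate scheduler with hand-rolled early stopping (model, data, and thresholds are placeholders for illustration):

```python
import torch
from torch import nn

# Placeholder model and synthetic train/validation data
model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
X_tr, y_tr = torch.randn(256, 10), torch.randn(256, 1)
X_val, y_val = torch.randn(64, 10), torch.randn(64, 1)

# Reduce the learning rate when the validation loss stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

best_val, patience, bad_epochs = float("inf"), 10, 0
for epoch in range(200):
    optimizer.zero_grad()
    loss_fn(model(X_tr), y_tr).backward()
    optimizer.step()

    with torch.no_grad():
        val_loss = loss_fn(model(X_val), y_val).item()
    scheduler.step(val_loss)

    if val_loss < best_val - 1e-4:   # improvement beyond a small threshold
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:   # early stopping
            print(f"Stopping early at epoch {epoch}")
            break
```
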
Proper diagnosis and resolution of training issues are vital for building robust AI systems. By systematically monitoring and adjusting the training process, one can significantly enhance model performance and reliability.