|
Facebook AI Similarity Search (Faiss) ํ |
|
================== |
|
# Faiss์ ๋ํ์ฌ |
|
Faiss ๋ Facebook Research๊ฐ ๊ฐ๋ฐํ๋, ๊ณ ๋ฐ๋ ๋ฒกํฐ ์ด์ ๊ฒ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์
๋๋ค. ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ (Approximate Neigbor Search)์ ์ฝ๊ฐ์ ์ ํ์ฑ์ ํฌ์ํ์ฌ ์ ์ฌ ๋ฒกํฐ๋ฅผ ๊ณ ์์ผ๋ก ์ฐพ์ต๋๋ค. |
|
|
|
## RVC์ ์์ด์ Faiss |
|
RVC์์๋ HuBERT๋ก ๋ณํํ feature์ embedding์ ์ํด ํ๋ จ ๋ฐ์ดํฐ์์ ์์ฑ๋ embedding๊ณผ ์ ์ฌํ embadding์ ๊ฒ์ํ๊ณ ํผํฉํ์ฌ ์๋์ ์์ฑ์ ๋์ฑ ๊ฐ๊น์ด ๋ณํ์ ๋ฌ์ฑํฉ๋๋ค. ๊ทธ๋ฌ๋, ์ด ํ์๋ฒ์ ๋จ์ํ ์ํํ๋ฉด ์๊ฐ์ด ๋ค์ ์๋ชจ๋๋ฏ๋ก, ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ํตํด ๊ณ ์ ๋ณํ์ ๊ฐ๋ฅ์ผ ํ๊ณ ์์ต๋๋ค. |
|
|
|
# ๊ตฌํ ๊ฐ์ |
|
๋ชจ๋ธ์ด ์์นํ `/logs/your-experiment/3_feature256`์๋ ๊ฐ ์์ฑ ๋ฐ์ดํฐ์์ HuBERT๊ฐ ์ถ์ถํ feature๋ค์ด ์์ต๋๋ค. ์ฌ๊ธฐ์์ ํ์ผ ์ด๋ฆ๋ณ๋ก ์ ๋ ฌ๋ npy ํ์ผ์ ์ฝ๊ณ , ๋ฒกํฐ๋ฅผ ์ฐ๊ฒฐํ์ฌ big_npy ([N, 256] ๋ชจ์์ ๋ฒกํฐ) ๋ฅผ ๋ง๋ญ๋๋ค. big_npy๋ฅผ `/logs/your-experiment/total_fea.npy`๋ก ์ ์ฅํ ํ, Faiss๋ก ํ์ต์ํต๋๋ค. |
|
|
|
2023/04/18 ๊ธฐ์ค์ผ๋ก, Faiss์ Index Factory ๊ธฐ๋ฅ์ ์ด์ฉํด, L2 ๊ฑฐ๋ฆฌ์ ๊ทผ๊ฑฐํ๋ IVF๋ฅผ ์ด์ฉํ๊ณ ์์ต๋๋ค. IVF์ ๋ถํ ์(n_ivf)๋ N//39๋ก, n_probe๋ int(np.power(n_ivf, 0.3))๊ฐ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. (infer-web.py์ train_index ์ฃผ์๋ฅผ ์ฐพ์ผ์ญ์์ค.) |
|
|
|
์ด ํ์์๋ ๋จผ์ ์ด๋ฌํ ๋งค๊ฐ ๋ณ์์ ์๋ฏธ๋ฅผ ์ค๋ช
ํ๊ณ , ๊ฐ๋ฐ์๊ฐ ์ถํ ๋ ๋์ index๋ฅผ ์์ฑํ ์ ์๋๋ก ํ๋ ์กฐ์ธ์ ์์ฑํฉ๋๋ค. |
|
|
|
# ๋ฐฉ๋ฒ์ ์ค๋ช
|
|
## Index factory |
|
index factory๋ ์ฌ๋ฌ ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ๋ฌธ์์ด๋ก ์ฐ๊ฒฐํ๋ pipeline์ ๋ฌธ์์ด๋ก ํ๊ธฐํ๋ Faiss๋ง์ ๋
์์ ์ธ ๊ธฐ๋ฒ์
๋๋ค. ์ด๋ฅผ ํตํด index factory์ ๋ฌธ์์ด์ ๋ณ๊ฒฝํ๋ ๊ฒ๋ง์ผ๋ก ๋ค์ํ ๊ทผ์ฌ ๊ทผ์ ํ์์ ์๋ํด ๋ณผ ์ ์์ต๋๋ค. RVC์์๋ ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉ๋ฉ๋๋ค: |
|
|
|
```python |
|
index = Faiss.index_factory(256, "IVF%s,Flat" % n_ivf) |
|
``` |
|
`index_factory`์ ์ธ์๋ค ์ค ์ฒซ ๋ฒ์งธ๋ ๋ฒกํฐ์ ์ฐจ์ ์์ด๊ณ , ๋๋ฒ์งธ๋ index factory ๋ฌธ์์ด์ด๋ฉฐ, ์ธ๋ฒ์งธ์๋ ์ฌ์ฉํ ๊ฑฐ๋ฆฌ๋ฅผ ์ง์ ํ ์ ์์ต๋๋ค. |
|
|
|
๊ธฐ๋ฒ์ ๋ณด๋ค ์์ธํ ์ค๋ช
์ https://github.com/facebookresearch/Faiss/wiki/The-index-factory ๋ฅผ ํ์ธํด ์ฃผ์ญ์์ค. |
|
|
|
## ๊ฑฐ๋ฆฌ์ ๋ํ index |
|
embedding์ ์ ์ฌ๋๋ก์ ์ฌ์ฉ๋๋ ๋ํ์ ์ธ ์งํ๋ก์ ์ดํ์ 2๊ฐ๊ฐ ์์ต๋๋ค. |
|
|
|
- ์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ (METRIC_L2) |
|
- ๋ด์ (ๅ
็ฉ) (METRIC_INNER_PRODUCT) |
|
|
|
์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ์์๋ ๊ฐ ์ฐจ์์์ ์ ๊ณฑ์ ์ฐจ๋ฅผ ๊ตฌํ๊ณ , ๊ฐ ์ฐจ์์์ ๊ตฌํ ์ฐจ๋ฅผ ๋ชจ๋ ๋ํ ํ ์ ๊ณฑ๊ทผ์ ์ทจํฉ๋๋ค. ์ด๊ฒ์ ์ผ์์ ์ผ๋ก ์ฌ์ฉ๋๋ 2์ฐจ์, 3์ฐจ์์์์ ๊ฑฐ๋ฆฌ์ ์ฐ์ฐ๋ฒ๊ณผ ๊ฐ์ต๋๋ค. ๋ด์ ์ ๊ทธ ๊ฐ์ ๊ทธ๋๋ก ์ ์ฌ๋ ์งํ๋ก ์ฌ์ฉํ์ง ์๊ณ , L2 ์ ๊ทํ๋ฅผ ํ ์ดํ ๋ด์ ์ ์ทจํ๋ ์ฝ์ฌ์ธ ์ ์ฌ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค. |
|
|
|
์ด๋ ์ชฝ์ด ๋ ์ข์์ง๋ ๊ฒฝ์ฐ์ ๋ฐ๋ผ ๋ค๋ฅด์ง๋ง, word2vec์์ ์ป์ embedding ๋ฐ ArcFace๋ฅผ ํ์ฉํ ์ด๋ฏธ์ง ๊ฒ์ ๋ชจ๋ธ์ ์ฝ์ฌ์ธ ์ ์ฌ์ฑ์ด ์ด์ฉ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. numpy๋ฅผ ์ฌ์ฉํ์ฌ ๋ฒกํฐ X์ ๋ํด L2 ์ ๊ทํ๋ฅผ ํ๊ณ ์ ํ๋ ๊ฒฝ์ฐ, 0 division์ ํผํ๊ธฐ ์ํด ์ถฉ๋ถํ ์์ ๊ฐ์ eps๋ก ํ ๋ค ์ดํ์ ์ฝ๋๋ฅผ ํ์ฉํ๋ฉด ๋ฉ๋๋ค. |
|
|
|
```python |
|
X_normed = X / np.maximum(eps, np.linalg.norm(X, ord=2, axis=-1, keepdims=True)) |
|
``` |
|
|
|
๋ํ, `index factory`์ 3๋ฒ์งธ ์ธ์์ ๊ฑด๋ค์ฃผ๋ ๊ฐ์ ์ ํํ๋ ๊ฒ์ ํตํด ๊ณ์ฐ์ ์ฌ์ฉํ๋ ๊ฑฐ๋ฆฌ index๋ฅผ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค. |
|
|
|
```python |
|
index = Faiss.index_factory(dimention, text, Faiss.METRIC_INNER_PRODUCT) |
|
``` |
|
|
|
## IVF |
|
IVF (Inverted file indexes)๋ ์ญ์์ธ ํ์๋ฒ๊ณผ ์ ์ฌํ ์๊ณ ๋ฆฌ์ฆ์
๋๋ค. ํ์ต์์๋ ๊ฒ์ ๋์์ ๋ํด k-ํ๊ท ๊ตฐ์ง๋ฒ์ ์ค์ํ๊ณ ํด๋ฌ์คํฐ ์ค์ฌ์ ์ด์ฉํด ๋ณด๋ก๋
ธ์ด ๋ถํ ์ ์ค์ํฉ๋๋ค. ๊ฐ ๋ฐ์ดํฐ ํฌ์ธํธ์๋ ํด๋ฌ์คํฐ๊ฐ ํ ๋น๋๋ฏ๋ก, ํด๋ฌ์คํฐ์์ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ์กฐํํ๋ dictionary๋ฅผ ๋ง๋ญ๋๋ค. |
|
|
|
์๋ฅผ ๋ค์ด, ํด๋ฌ์คํฐ๊ฐ ๋ค์๊ณผ ๊ฐ์ด ํ ๋น๋ ๊ฒฝ์ฐ |
|
|index|Cluster| |
|
|-----|-------| |
|
|1|A| |
|
|2|B| |
|
|3|A| |
|
|4|C| |
|
|5|B| |
|
|
|
IVF ์ดํ์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค: |
|
|
|
|cluster|index| |
|
|-------|-----| |
|
|A|1, 3| |
|
|B|2, 5| |
|
|C|4| |
|
|
|
ํ์ ์, ์ฐ์ ํด๋ฌ์คํฐ์์ `n_probe`๊ฐ์ ํด๋ฌ์คํฐ๋ฅผ ํ์ํ ๋ค์, ๊ฐ ํด๋ฌ์คํฐ์ ์ํ ๋ฐ์ดํฐ ํฌ์ธํธ์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํฉ๋๋ค. |
|
|
|
# ๊ถ์ฅ ๋งค๊ฐ๋ณ์ |
|
index์ ์ ํ ๋ฐฉ๋ฒ์ ๋ํด์๋ ๊ณต์์ ์ผ๋ก ๊ฐ์ด๋ ๋ผ์ธ์ด ์์ผ๋ฏ๋ก, ๊ฑฐ๊ธฐ์ ์คํด ์ค๋ช
ํฉ๋๋ค. |
|
https://github.com/facebookresearch/Faiss/wiki/Guidelines-to-choose-an-index |
|
|
|
1M ์ดํ์ ๋ฐ์ดํฐ ์ธํธ์ ์์ด์๋ 4bit-PQ๊ฐ 2023๋
4์ ์์ ์์๋ Faiss๋ก ์ด์ฉํ ์ ์๋ ๊ฐ์ฅ ํจ์จ์ ์ธ ์๋ฒ์
๋๋ค. ์ด๊ฒ์ IVF์ ์กฐํฉํด, 4bit-PQ๋ก ํ๋ณด๋ฅผ ์ถ๋ ค๋ด๊ณ , ๋ง์ง๋ง์ผ๋ก ์ดํ์ index factory๋ฅผ ์ด์ฉํ์ฌ ์ ํํ ์งํ๋ก ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ๊ณ์ฐํ๋ฉด ๋ฉ๋๋ค. |
|
|
|
```python |
|
index = Faiss.index_factory(256, "IVF1024,PQ128x4fs,RFlat") |
|
``` |
|
|
|
## IVF ๊ถ์ฅ ๋งค๊ฐ๋ณ์ |
|
IVF์ ์๊ฐ ๋๋ฌด ๋ง์ผ๋ฉด, ๊ฐ๋ น ๋ฐ์ดํฐ ์์ ์๋งํผ IVF๋ก ์์ํ(Quantization)๋ฅผ ์ํํ๋ฉด, ์ด๊ฒ์ ์์ ํ์๊ณผ ๊ฐ์์ ธ ํจ์จ์ด ๋๋น ์ง๊ฒ ๋ฉ๋๋ค. 1M ์ดํ์ ๊ฒฝ์ฐ IVF ๊ฐ์ ๋ฐ์ดํฐ ํฌ์ธํธ ์ N์ ๋ํด 4sqrt(N) ~ 16sqrt(N)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. |
|
|
|
n_probe๋ n_probe์ ์์ ๋น๋กํ์ฌ ๊ณ์ฐ ์๊ฐ์ด ๋์ด๋๋ฏ๋ก ์ ํ๋์ ์๊ฐ์ ์ ์ ํ ๊ท ํ์ ๋ง์ถ์ด ์ฃผ์ญ์์ค. ๊ฐ์ธ์ ์ผ๋ก RVC์ ์์ด์ ๊ทธ๋ ๊ฒ๊น์ง ์ ํ๋๋ ํ์ ์๋ค๊ณ ์๊ฐํ๊ธฐ ๋๋ฌธ์ n_probe = 1์ด๋ฉด ๋๋ค๊ณ ์๊ฐํฉ๋๋ค. |
|
|
|
## FastScan |
|
FastScan์ ์ง์ ์์ํ๋ฅผ ๋ ์ง์คํฐ์์ ์ํํจ์ผ๋ก์จ ๊ฑฐ๋ฆฌ์ ๊ณ ์ ๊ทผ์ฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ ๋ฐฉ๋ฒ์
๋๋ค.์ง์ ์์ํ๋ ํ์ต์์ d์ฐจ์๋ง๋ค(๋ณดํต d=2)์ ๋
๋ฆฝ์ ์ผ๋ก ํด๋ฌ์คํฐ๋ง์ ์ค์ํด, ํด๋ฌ์คํฐ๋ผ๋ฆฌ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ์ ๊ณ์ฐํด lookup table๋ฅผ ์์ฑํฉ๋๋ค. ์์ธก์๋ lookup table์ ๋ณด๋ฉด ๊ฐ ์ฐจ์์ ๊ฑฐ๋ฆฌ๋ฅผ O(1)๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ PQ ๋ค์์ ์ง์ ํ๋ ์ซ์๋ ์ผ๋ฐ์ ์ผ๋ก ๋ฒกํฐ์ ์ ๋ฐ ์ฐจ์์ ์ง์ ํฉ๋๋ค. |
|
|
|
FastScan์ ๋ํ ์์ธํ ์ค๋ช
์ ๊ณต์ ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ญ์์ค. |
|
https://github.com/facebookresearch/Faiss/wiki/Fast-accumulation-of-PQ-and-AQ-codes-(FastScan) |
|
|
|
## RFlat |
|
RFlat์ FastScan์ด ๊ณ์ฐํ ๋๋ต์ ์ธ ๊ฑฐ๋ฆฌ๋ฅผ index factory์ 3๋ฒ์งธ ์ธ์๋ก ์ง์ ํ ์ ํํ ๊ฑฐ๋ฆฌ๋ก ๋ค์ ๊ณ์ฐํ๋ผ๋ ์ธ์คํธ๋ญ์
์
๋๋ค. k๊ฐ์ ๊ทผ์ ๋ณ์๋ฅผ ๊ฐ์ ธ์ฌ ๋ k*k_factor๊ฐ์ ์ ์ ๋ํด ์ฌ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค. |
|
|
|
# Embedding ํ
ํฌ๋ |
|
## Alpha ์ฟผ๋ฆฌ ํ์ฅ |
|
ํด๋ฆฌ ํ์ฅ์ด๋ ํ์์์ ์ฌ์ฉ๋๋ ๊ธฐ์ ๋ก, ์๋ฅผ ๋ค์ด ์ ๋ฌธ ํ์ ์, ์
๋ ฅ๋ ๊ฒ์๋ฌธ์ ๋จ์ด๋ฅผ ๋ช ๊ฐ๋ฅผ ์ถ๊ฐํจ์ผ๋ก์จ ๊ฒ์ ์ ํ๋๋ฅผ ์ฌ๋ฆฌ๋ ๋ฐฉ๋ฒ์
๋๋ค. ๋ฐฑํฐ ํ์์ ์ํด์๋ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ์ ์๋์๋๋ฐ, ๊ทธ ์ค ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ์ถ๊ฐ ํ์ต์ด ํ์ ์๋ ๋งค์ฐ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. [Attention-Based Query Expansion Learning](https://arxiv.org/abs/2007.08019)์ [2nd place solution of kaggle shopee competition](https://www.kaggle.com/code/lyakaap/2nd-place-solution/notebook) ๋
ผ๋ฌธ์์ ์๊ฐ๋ ๋ฐ ์์ต๋๋ค.. |
|
|
|
ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ํ ๋ฒกํฐ์ ์ธ์ ํ ๋ฒกํฐ๋ฅผ ์ ์ฌ๋์ ฮฑ๊ณฑํ ๊ฐ์ค์น๋ก ๋ํด์ฃผ๋ฉด ๋ฉ๋๋ค. ์ฝ๋๋ก ์์๋ฅผ ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค. big_npy๋ฅผ ฮฑ query expansion๋ก ๋์ฒดํฉ๋๋ค. |
|
|
|
```python |
|
alpha = 3. |
|
index = Faiss.index_factory(256, "IVF512,PQ128x4fs,RFlat") |
|
original_norm = np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9) |
|
big_npy /= original_norm |
|
index.train(big_npy) |
|
index.add(big_npy) |
|
dist, neighbor = index.search(big_npy, num_expand) |
|
|
|
expand_arrays = [] |
|
ixs = np.arange(big_npy.shape[0]) |
|
for i in range(-(-big_npy.shape[0]//batch_size)): |
|
ix = ixs[i*batch_size:(i+1)*batch_size] |
|
weight = np.power(np.einsum("nd,nmd->nm", big_npy[ix], big_npy[neighbor[ix]]), alpha) |
|
expand_arrays.append(np.sum(big_npy[neighbor[ix]] * np.expand_dims(weight, axis=2),axis=1)) |
|
big_npy = np.concatenate(expand_arrays, axis=0) |
|
|
|
# index version ์ ๊ทํ |
|
big_npy = big_npy / np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9) |
|
``` |
|
|
|
์ ํ
ํฌ๋์ ํ์์ ์ํํ๋ ์ฟผ๋ฆฌ์๋, ํ์ ๋์ DB์๋ ์ ์ ๊ฐ๋ฅํ ํ
ํฌ๋์
๋๋ค. |
|
|
|
## MiniBatch KMeans์ ์ํ embedding ์์ถ |
|
|
|
total_fea.npy๊ฐ ๋๋ฌด ํด ๊ฒฝ์ฐ K-means๋ฅผ ์ด์ฉํ์ฌ ๋ฒกํฐ๋ฅผ ์๊ฒ ๋ง๋๋ ๊ฒ์ด ๊ฐ๋ฅํฉ๋๋ค. ์ดํ ์ฝ๋๋ก embedding์ ์์ถ์ด ๊ฐ๋ฅํฉ๋๋ค. n_clusters์ ์์ถํ๊ณ ์ ํ๋ ํฌ๊ธฐ๋ฅผ ์ง์ ํ๊ณ batch_size์ 256 * CPU์ ์ฝ์ด ์๋ฅผ ์ง์ ํจ์ผ๋ก์จ CPU ๋ณ๋ ฌํ์ ํํ์ ์ถฉ๋ถํ ์ป์ ์ ์์ต๋๋ค. |
|
|
|
```python |
|
import multiprocessing |
|
from sklearn.cluster import MiniBatchKMeans |
|
kmeans = MiniBatchKMeans(n_clusters=10000, batch_size=256 * multiprocessing.cpu_count(), init="random") |
|
kmeans.fit(big_npy) |
|
sample_npy = kmeans.cluster_centers_ |
|
``` |