Facebook AI Similarity Search (Faiss) ํ
Faiss์ ๋ํ์ฌ
Faiss ๋ Facebook Research๊ฐ ๊ฐ๋ฐํ๋, ๊ณ ๋ฐ๋ ๋ฒกํฐ ์ด์ ๊ฒ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋๋ค. ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ (Approximate Neigbor Search)์ ์ฝ๊ฐ์ ์ ํ์ฑ์ ํฌ์ํ์ฌ ์ ์ฌ ๋ฒกํฐ๋ฅผ ๊ณ ์์ผ๋ก ์ฐพ์ต๋๋ค.
RVC์ ์์ด์ Faiss
RVC์์๋ HuBERT๋ก ๋ณํํ feature์ embedding์ ์ํด ํ๋ จ ๋ฐ์ดํฐ์์ ์์ฑ๋ embedding๊ณผ ์ ์ฌํ embadding์ ๊ฒ์ํ๊ณ ํผํฉํ์ฌ ์๋์ ์์ฑ์ ๋์ฑ ๊ฐ๊น์ด ๋ณํ์ ๋ฌ์ฑํฉ๋๋ค. ๊ทธ๋ฌ๋, ์ด ํ์๋ฒ์ ๋จ์ํ ์ํํ๋ฉด ์๊ฐ์ด ๋ค์ ์๋ชจ๋๋ฏ๋ก, ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ํตํด ๊ณ ์ ๋ณํ์ ๊ฐ๋ฅ์ผ ํ๊ณ ์์ต๋๋ค.
๊ตฌํ ๊ฐ์
๋ชจ๋ธ์ด ์์นํ /logs/your-experiment/3_feature256
์๋ ๊ฐ ์์ฑ ๋ฐ์ดํฐ์์ HuBERT๊ฐ ์ถ์ถํ feature๋ค์ด ์์ต๋๋ค. ์ฌ๊ธฐ์์ ํ์ผ ์ด๋ฆ๋ณ๋ก ์ ๋ ฌ๋ npy ํ์ผ์ ์ฝ๊ณ , ๋ฒกํฐ๋ฅผ ์ฐ๊ฒฐํ์ฌ big_npy ([N, 256] ๋ชจ์์ ๋ฒกํฐ) ๋ฅผ ๋ง๋ญ๋๋ค. big_npy๋ฅผ /logs/your-experiment/total_fea.npy
๋ก ์ ์ฅํ ํ, Faiss๋ก ํ์ต์ํต๋๋ค.
2023/04/18 ๊ธฐ์ค์ผ๋ก, Faiss์ Index Factory ๊ธฐ๋ฅ์ ์ด์ฉํด, L2 ๊ฑฐ๋ฆฌ์ ๊ทผ๊ฑฐํ๋ IVF๋ฅผ ์ด์ฉํ๊ณ ์์ต๋๋ค. IVF์ ๋ถํ ์(n_ivf)๋ N//39๋ก, n_probe๋ int(np.power(n_ivf, 0.3))๊ฐ ์ฌ์ฉ๋๊ณ ์์ต๋๋ค. (web.py์ train_index ์ฃผ์๋ฅผ ์ฐพ์ผ์ญ์์ค.)
์ด ํ์์๋ ๋จผ์ ์ด๋ฌํ ๋งค๊ฐ ๋ณ์์ ์๋ฏธ๋ฅผ ์ค๋ช ํ๊ณ , ๊ฐ๋ฐ์๊ฐ ์ถํ ๋ ๋์ index๋ฅผ ์์ฑํ ์ ์๋๋ก ํ๋ ์กฐ์ธ์ ์์ฑํฉ๋๋ค.
๋ฐฉ๋ฒ์ ์ค๋ช
Index factory
index factory๋ ์ฌ๋ฌ ๊ทผ์ฌ ๊ทผ์ ํ์๋ฒ์ ๋ฌธ์์ด๋ก ์ฐ๊ฒฐํ๋ pipeline์ ๋ฌธ์์ด๋ก ํ๊ธฐํ๋ Faiss๋ง์ ๋ ์์ ์ธ ๊ธฐ๋ฒ์ ๋๋ค. ์ด๋ฅผ ํตํด index factory์ ๋ฌธ์์ด์ ๋ณ๊ฒฝํ๋ ๊ฒ๋ง์ผ๋ก ๋ค์ํ ๊ทผ์ฌ ๊ทผ์ ํ์์ ์๋ํด ๋ณผ ์ ์์ต๋๋ค. RVC์์๋ ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉ๋ฉ๋๋ค:
index = Faiss.index_factory(256, "IVF%s,Flat" % n_ivf)
index_factory
์ ์ธ์๋ค ์ค ์ฒซ ๋ฒ์งธ๋ ๋ฒกํฐ์ ์ฐจ์ ์์ด๊ณ , ๋๋ฒ์งธ๋ index factory ๋ฌธ์์ด์ด๋ฉฐ, ์ธ๋ฒ์งธ์๋ ์ฌ์ฉํ ๊ฑฐ๋ฆฌ๋ฅผ ์ง์ ํ ์ ์์ต๋๋ค.
๊ธฐ๋ฒ์ ๋ณด๋ค ์์ธํ ์ค๋ช ์ https://github.com/facebookresearch/Faiss/wiki/The-index-factory ๋ฅผ ํ์ธํด ์ฃผ์ญ์์ค.
๊ฑฐ๋ฆฌ์ ๋ํ index
embedding์ ์ ์ฌ๋๋ก์ ์ฌ์ฉ๋๋ ๋ํ์ ์ธ ์งํ๋ก์ ์ดํ์ 2๊ฐ๊ฐ ์์ต๋๋ค.
- ์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ (METRIC_L2)
- ๋ด์ (ๅ ็ฉ) (METRIC_INNER_PRODUCT)
์ ํด๋ฆฌ๋ ๊ฑฐ๋ฆฌ์์๋ ๊ฐ ์ฐจ์์์ ์ ๊ณฑ์ ์ฐจ๋ฅผ ๊ตฌํ๊ณ , ๊ฐ ์ฐจ์์์ ๊ตฌํ ์ฐจ๋ฅผ ๋ชจ๋ ๋ํ ํ ์ ๊ณฑ๊ทผ์ ์ทจํฉ๋๋ค. ์ด๊ฒ์ ์ผ์์ ์ผ๋ก ์ฌ์ฉ๋๋ 2์ฐจ์, 3์ฐจ์์์์ ๊ฑฐ๋ฆฌ์ ์ฐ์ฐ๋ฒ๊ณผ ๊ฐ์ต๋๋ค. ๋ด์ ์ ๊ทธ ๊ฐ์ ๊ทธ๋๋ก ์ ์ฌ๋ ์งํ๋ก ์ฌ์ฉํ์ง ์๊ณ , L2 ์ ๊ทํ๋ฅผ ํ ์ดํ ๋ด์ ์ ์ทจํ๋ ์ฝ์ฌ์ธ ์ ์ฌ๋๋ฅผ ์ฌ์ฉํฉ๋๋ค.
์ด๋ ์ชฝ์ด ๋ ์ข์์ง๋ ๊ฒฝ์ฐ์ ๋ฐ๋ผ ๋ค๋ฅด์ง๋ง, word2vec์์ ์ป์ embedding ๋ฐ ArcFace๋ฅผ ํ์ฉํ ์ด๋ฏธ์ง ๊ฒ์ ๋ชจ๋ธ์ ์ฝ์ฌ์ธ ์ ์ฌ์ฑ์ด ์ด์ฉ๋๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. numpy๋ฅผ ์ฌ์ฉํ์ฌ ๋ฒกํฐ X์ ๋ํด L2 ์ ๊ทํ๋ฅผ ํ๊ณ ์ ํ๋ ๊ฒฝ์ฐ, 0 division์ ํผํ๊ธฐ ์ํด ์ถฉ๋ถํ ์์ ๊ฐ์ eps๋ก ํ ๋ค ์ดํ์ ์ฝ๋๋ฅผ ํ์ฉํ๋ฉด ๋ฉ๋๋ค.
X_normed = X / np.maximum(eps, np.linalg.norm(X, ord=2, axis=-1, keepdims=True))
๋ํ, index factory
์ 3๋ฒ์งธ ์ธ์์ ๊ฑด๋ค์ฃผ๋ ๊ฐ์ ์ ํํ๋ ๊ฒ์ ํตํด ๊ณ์ฐ์ ์ฌ์ฉํ๋ ๊ฑฐ๋ฆฌ index๋ฅผ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
index = Faiss.index_factory(dimention, text, Faiss.METRIC_INNER_PRODUCT)
IVF
IVF (Inverted file indexes)๋ ์ญ์์ธ ํ์๋ฒ๊ณผ ์ ์ฌํ ์๊ณ ๋ฆฌ์ฆ์ ๋๋ค. ํ์ต์์๋ ๊ฒ์ ๋์์ ๋ํด k-ํ๊ท ๊ตฐ์ง๋ฒ์ ์ค์ํ๊ณ ํด๋ฌ์คํฐ ์ค์ฌ์ ์ด์ฉํด ๋ณด๋ก๋ ธ์ด ๋ถํ ์ ์ค์ํฉ๋๋ค. ๊ฐ ๋ฐ์ดํฐ ํฌ์ธํธ์๋ ํด๋ฌ์คํฐ๊ฐ ํ ๋น๋๋ฏ๋ก, ํด๋ฌ์คํฐ์์ ๋ฐ์ดํฐ ํฌ์ธํธ๋ฅผ ์กฐํํ๋ dictionary๋ฅผ ๋ง๋ญ๋๋ค.
์๋ฅผ ๋ค์ด, ํด๋ฌ์คํฐ๊ฐ ๋ค์๊ณผ ๊ฐ์ด ํ ๋น๋ ๊ฒฝ์ฐ
index | Cluster |
---|---|
1 | A |
2 | B |
3 | A |
4 | C |
5 | B |
IVF ์ดํ์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
cluster | index |
---|---|
A | 1, 3 |
B | 2, 5 |
C | 4 |
ํ์ ์, ์ฐ์ ํด๋ฌ์คํฐ์์ n_probe
๊ฐ์ ํด๋ฌ์คํฐ๋ฅผ ํ์ํ ๋ค์, ๊ฐ ํด๋ฌ์คํฐ์ ์ํ ๋ฐ์ดํฐ ํฌ์ธํธ์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
๊ถ์ฅ ๋งค๊ฐ๋ณ์
index์ ์ ํ ๋ฐฉ๋ฒ์ ๋ํด์๋ ๊ณต์์ ์ผ๋ก ๊ฐ์ด๋ ๋ผ์ธ์ด ์์ผ๋ฏ๋ก, ๊ฑฐ๊ธฐ์ ์คํด ์ค๋ช ํฉ๋๋ค. https://github.com/facebookresearch/Faiss/wiki/Guidelines-to-choose-an-index
1M ์ดํ์ ๋ฐ์ดํฐ ์ธํธ์ ์์ด์๋ 4bit-PQ๊ฐ 2023๋ 4์ ์์ ์์๋ Faiss๋ก ์ด์ฉํ ์ ์๋ ๊ฐ์ฅ ํจ์จ์ ์ธ ์๋ฒ์ ๋๋ค. ์ด๊ฒ์ IVF์ ์กฐํฉํด, 4bit-PQ๋ก ํ๋ณด๋ฅผ ์ถ๋ ค๋ด๊ณ , ๋ง์ง๋ง์ผ๋ก ์ดํ์ index factory๋ฅผ ์ด์ฉํ์ฌ ์ ํํ ์งํ๋ก ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ๊ณ์ฐํ๋ฉด ๋ฉ๋๋ค.
index = Faiss.index_factory(256, "IVF1024,PQ128x4fs,RFlat")
IVF ๊ถ์ฅ ๋งค๊ฐ๋ณ์
IVF์ ์๊ฐ ๋๋ฌด ๋ง์ผ๋ฉด, ๊ฐ๋ น ๋ฐ์ดํฐ ์์ ์๋งํผ IVF๋ก ์์ํ(Quantization)๋ฅผ ์ํํ๋ฉด, ์ด๊ฒ์ ์์ ํ์๊ณผ ๊ฐ์์ ธ ํจ์จ์ด ๋๋น ์ง๊ฒ ๋ฉ๋๋ค. 1M ์ดํ์ ๊ฒฝ์ฐ IVF ๊ฐ์ ๋ฐ์ดํฐ ํฌ์ธํธ ์ N์ ๋ํด 4sqrt(N) ~ 16sqrt(N)๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
n_probe๋ n_probe์ ์์ ๋น๋กํ์ฌ ๊ณ์ฐ ์๊ฐ์ด ๋์ด๋๋ฏ๋ก ์ ํ๋์ ์๊ฐ์ ์ ์ ํ ๊ท ํ์ ๋ง์ถ์ด ์ฃผ์ญ์์ค. ๊ฐ์ธ์ ์ผ๋ก RVC์ ์์ด์ ๊ทธ๋ ๊ฒ๊น์ง ์ ํ๋๋ ํ์ ์๋ค๊ณ ์๊ฐํ๊ธฐ ๋๋ฌธ์ n_probe = 1์ด๋ฉด ๋๋ค๊ณ ์๊ฐํฉ๋๋ค.
FastScan
FastScan์ ์ง์ ์์ํ๋ฅผ ๋ ์ง์คํฐ์์ ์ํํจ์ผ๋ก์จ ๊ฑฐ๋ฆฌ์ ๊ณ ์ ๊ทผ์ฌ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.์ง์ ์์ํ๋ ํ์ต์์ d์ฐจ์๋ง๋ค(๋ณดํต d=2)์ ๋ ๋ฆฝ์ ์ผ๋ก ํด๋ฌ์คํฐ๋ง์ ์ค์ํด, ํด๋ฌ์คํฐ๋ผ๋ฆฌ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ฌ์ ๊ณ์ฐํด lookup table๋ฅผ ์์ฑํฉ๋๋ค. ์์ธก์๋ lookup table์ ๋ณด๋ฉด ๊ฐ ์ฐจ์์ ๊ฑฐ๋ฆฌ๋ฅผ O(1)๋ก ๊ณ์ฐํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ PQ ๋ค์์ ์ง์ ํ๋ ์ซ์๋ ์ผ๋ฐ์ ์ผ๋ก ๋ฒกํฐ์ ์ ๋ฐ ์ฐจ์์ ์ง์ ํฉ๋๋ค.
FastScan์ ๋ํ ์์ธํ ์ค๋ช ์ ๊ณต์ ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ญ์์ค. https://github.com/facebookresearch/Faiss/wiki/Fast-accumulation-of-PQ-and-AQ-codes-(FastScan)
RFlat
RFlat์ FastScan์ด ๊ณ์ฐํ ๋๋ต์ ์ธ ๊ฑฐ๋ฆฌ๋ฅผ index factory์ 3๋ฒ์งธ ์ธ์๋ก ์ง์ ํ ์ ํํ ๊ฑฐ๋ฆฌ๋ก ๋ค์ ๊ณ์ฐํ๋ผ๋ ์ธ์คํธ๋ญ์ ์ ๋๋ค. k๊ฐ์ ๊ทผ์ ๋ณ์๋ฅผ ๊ฐ์ ธ์ฌ ๋ k*k_factor๊ฐ์ ์ ์ ๋ํด ์ฌ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค.
Embedding ํ ํฌ๋
Alpha ์ฟผ๋ฆฌ ํ์ฅ
ํด๋ฆฌ ํ์ฅ์ด๋ ํ์์์ ์ฌ์ฉ๋๋ ๊ธฐ์ ๋ก, ์๋ฅผ ๋ค์ด ์ ๋ฌธ ํ์ ์, ์ ๋ ฅ๋ ๊ฒ์๋ฌธ์ ๋จ์ด๋ฅผ ๋ช ๊ฐ๋ฅผ ์ถ๊ฐํจ์ผ๋ก์จ ๊ฒ์ ์ ํ๋๋ฅผ ์ฌ๋ฆฌ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ฐฑํฐ ํ์์ ์ํด์๋ ๋ช๊ฐ์ง ๋ฐฉ๋ฒ์ด ์ ์๋์๋๋ฐ, ๊ทธ ์ค ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ์ถ๊ฐ ํ์ต์ด ํ์ ์๋ ๋งค์ฐ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. Attention-Based Query Expansion Learning์ 2nd place solution of kaggle shopee competition ๋ ผ๋ฌธ์์ ์๊ฐ๋ ๋ฐ ์์ต๋๋ค..
ฮฑ-์ฟผ๋ฆฌ ํ์ฅ์ ํ ๋ฒกํฐ์ ์ธ์ ํ ๋ฒกํฐ๋ฅผ ์ ์ฌ๋์ ฮฑ๊ณฑํ ๊ฐ์ค์น๋ก ๋ํด์ฃผ๋ฉด ๋ฉ๋๋ค. ์ฝ๋๋ก ์์๋ฅผ ๋ค์ด ๋ณด๊ฒ ์ต๋๋ค. big_npy๋ฅผ ฮฑ query expansion๋ก ๋์ฒดํฉ๋๋ค.
alpha = 3.
index = Faiss.index_factory(256, "IVF512,PQ128x4fs,RFlat")
original_norm = np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)
big_npy /= original_norm
index.train(big_npy)
index.add(big_npy)
dist, neighbor = index.search(big_npy, num_expand)
expand_arrays = []
ixs = np.arange(big_npy.shape[0])
for i in range(-(-big_npy.shape[0]//batch_size)):
ix = ixs[i*batch_size:(i+1)*batch_size]
weight = np.power(np.einsum("nd,nmd->nm", big_npy[ix], big_npy[neighbor[ix]]), alpha)
expand_arrays.append(np.sum(big_npy[neighbor[ix]] * np.expand_dims(weight, axis=2),axis=1))
big_npy = np.concatenate(expand_arrays, axis=0)
# index version ์ ๊ทํ
big_npy = big_npy / np.maximum(np.linalg.norm(big_npy, ord=2, axis=1, keepdims=True), 1e-9)
์ ํ ํฌ๋์ ํ์์ ์ํํ๋ ์ฟผ๋ฆฌ์๋, ํ์ ๋์ DB์๋ ์ ์ ๊ฐ๋ฅํ ํ ํฌ๋์ ๋๋ค.
MiniBatch KMeans์ ์ํ embedding ์์ถ
total_fea.npy๊ฐ ๋๋ฌด ํด ๊ฒฝ์ฐ K-means๋ฅผ ์ด์ฉํ์ฌ ๋ฒกํฐ๋ฅผ ์๊ฒ ๋ง๋๋ ๊ฒ์ด ๊ฐ๋ฅํฉ๋๋ค. ์ดํ ์ฝ๋๋ก embedding์ ์์ถ์ด ๊ฐ๋ฅํฉ๋๋ค. n_clusters์ ์์ถํ๊ณ ์ ํ๋ ํฌ๊ธฐ๋ฅผ ์ง์ ํ๊ณ batch_size์ 256 * CPU์ ์ฝ์ด ์๋ฅผ ์ง์ ํจ์ผ๋ก์จ CPU ๋ณ๋ ฌํ์ ํํ์ ์ถฉ๋ถํ ์ป์ ์ ์์ต๋๋ค.
import multiprocessing
from sklearn.cluster import MiniBatchKMeans
kmeans = MiniBatchKMeans(n_clusters=10000, batch_size=256 * multiprocessing.cpu_count(), init="random")
kmeans.fit(big_npy)
sample_npy = kmeans.cluster_centers_