4 min read

DuckDB でハイブリッド検索

DuckDB を利用してベクトル検索と日本語全文検索の両方を同時に利用できます。さらにこれらの結果をマージして Reranking を行うことでハイブリッド検索をサクサクっと実現する事が​

Reranker

どうやらベクトル検索した結果と日本語全文検索した結果をマージして、クエリーとマージ結果を再度ランキング付けする仕組みのようです。

ここでは参考にした記事を共有する程度にしておきます。

今回は Reranker に hotchpotch/japanese-reranker-cross-encoder-large-v1 を利用しました。


以下は参考コードです。

[project]
name = "duckdb-hybrid-search"
version = "0.1.0"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "duckdb>=1.2.2",
    "lindera-py>=0.41.0",
    "numpy>=2.2.5",
    "sentence-transformers>=4.1.0",
    "sentencepiece>=0.2.0",
    "torch>=2.7.0",
    "transformers>=4.51.3",
]
# SPDX-License-Identifier: Apache-2.0
import duckdb
import torch
from lindera_py import Segmenter, Tokenizer, load_dictionary
from sentence_transformers import CrossEncoder
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

v_tokenizer = AutoTokenizer.from_pretrained(
    "pfnet/plamo-embedding-1b", trust_remote_code=True
)
v_model = AutoModel.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
v_model = v_model.to(device)

dictionary = load_dictionary("ipadic")
segmenter = Segmenter("normal", dictionary)
tokenizer = Tokenizer(segmenter)

r_model = CrossEncoder(
    "hotchpotch/japanese-bge-reranker-v2-m3-v1", max_length=512, device=device
)


def ja_tokens(text: str) -> str:
    return " ".join(t.text for t in tokenizer.tokenize(text))


def main():
    conn = duckdb.connect()
    conn.install_extension("vss")
    conn.load_extension("vss")
    conn.install_extension("fts")
    conn.load_extension("fts")

    conn.sql("CREATE SEQUENCE IF NOT EXISTS id_sequence START 1;")
    conn.sql("""
        CREATE TABLE sora_doc (
            id INTEGER DEFAULT nextval('id_sequence') PRIMARY KEY,
            content VARCHAR,  
            content_v FLOAT[2048],
            content_t VARCHAR
        );
    """)

    # https://sora-doc.shiguredo.jp/ より引用
    docs = [
        "例えば 3 ノードのクラスターがある場合、 すでに接続しているクライアントがいるノードとは異なるノードにクライアントが接続した場合、Sora はその異なるノードにすでに接続しているクライアントの音声や映像、データをリレーします。",
        "StartRecording API やセッションウェブフックの戻り値で指定できる録画メタデータについてはセンシティブなデータとして扱っていません。これは録画ファイル出力時の録画メタデータファイルに含まれ、映像合成時に利用する事を想定しているためです。",
        "WebSocket は TCP ベースのため Head of Line Blocking が存在し、不安定な回線などでパケットが詰まってしまうことがあります。 DataChannel は WebSocket とは異なり、パケットを並列でやりとりできるため、不安定な回線などでもパケットが詰まることが少なくなります。 シグナリングを WebSocket 経由から DataChannel 経由へ切り替える機能を提供することでより安定した接続が維持できます。",
    ]

    with torch.inference_mode():
        for doc, doc_embedding in zip(docs, v_model.encode_document(docs, v_tokenizer)):
            conn.execute(
                "INSERT INTO sora_doc (content, content_v, content_t) VALUES (?, ?, ?)",
                [
                    doc,
                    doc_embedding.cpu().squeeze().numpy().tolist(),
                    ja_tokens(doc),
                ],
            )

    conn.sql("""
        PRAGMA create_fts_index(
            'sora_doc',
            'id',
            'content_t',

            stemmer = 'none',
            stopwords = 'none',
            ignore = '',
            lower = false,
            strip_accents = false
        );
    """)

    hybrid_search(conn, "センシティブデータについて教えてください")


def fts_search(conn, query):
    q_tokens = ja_tokens(query)
    rows = conn.sql(f"""
        SELECT id, fts_main_sora_doc.match_bm25(id, '{q_tokens}') AS score, content
        FROM sora_doc
        WHERE score IS NOT NULL
        ORDER BY score DESC
    """).fetchall()

    return rows


def vss_search(conn, query):
    with torch.inference_mode():
        query_embedding = v_model.encode_query(query, v_tokenizer)
        rows = conn.sql(
            """
            SELECT id, array_cosine_distance(content_v, ?::FLOAT[2048]) as distance, content
            FROM sora_doc
            ORDER BY distance ASC
            """,
            params=[query_embedding.cpu().squeeze().numpy().tolist()],
        ).fetchall()

        return rows


def reranking(query, vss_rows, fts_rows):
    # 凄く雑なマージ、あとから content から id をとれるようにしてる
    passages = {}
    for row in vss_rows:
        id, _, content = row
        passages[content] = id
    for row in fts_rows:
        id, _, content = row
        passages[content] = id

    contents = list(passages.keys())
    # Reranker
    scores = r_model.predict([(query, content) for content in contents])
    # スコア高い順にソートするタイミングで id と content を score に紐づける
    return sorted(
        [
            (passages[content], score, content)
            for content, score in zip(contents, scores)
        ],
        key=lambda x: x[1],
        reverse=True,
    )


def hybrid_search(conn, query):
    print("query:", query)

    # FTS
    print("--- DuckDB-FTS + Lindera ---")
    fts_rows = fts_search(conn, query)
    for id, score, content in fts_rows:
        print(f"ID: {id}, Score: {score:.4f}, Content: {content}")

    # VSS
    print("--- DuckDB-VSS + PLaMo ---")
    vss_rows = vss_search(conn, query)
    for id, score, content in vss_rows:
        print(f"ID: {id}, Score: {score:.4f}, Content: {content}")

    # Reranking
    print("--- Reranking ---")
    reranking_rows = reranking(query, vss_rows, fts_rows)
    for id, score, content in reranking_rows:
        print(f"ID: {id}, Score: {score:.4f}, Content: {content}")


if __name__ == "__main__":
    main()
    # query: センシティブデータについて教えてください
    # --- DuckDB-FTS + Lindera ---
    # ID: 2, Score: 4.5369, Content: StartRecording API やセッションウェブフックの戻り値で指定できる録画メタデータについてはセンシティブなデータとして扱っていません。これは録画ファイル出力時の録画メタデータファイルに含まれ、映像合成時に利用する事を想定しているためです。
    # ID: 1, Score: 1.7540, Content: 例えば 3 ノードのクラスターがある場合、 すでに接続しているクライアントがいるノードとは異なるノードにクライアントが接続した場合、Sora はその異なるノードにすでに接続しているクライアントの音声や映像、データをリレーします。
    # ID: 3, Score: 0.8607, Content: WebSocket は TCP ベースのため Head of Line Blocking が存在し、不安定な回線などでパケットが詰まってしまうことがあります。 DataChannel は WebSocket とは異なり、パケットを並列でやりとりできるため、不安定な回線などでもパケットが詰まることが少なくなります。 シグナリングを WebSocket 経由から DataChannel 経由へ切り替える機能を提供することでより安定した接続が維持できます。
    # --- DuckDB-VSS + PLaMo ---
    # ID: 2, Score: 0.3009, Content: StartRecording API やセッションウェブフックの戻り値で指定できる録画メタデータについてはセンシティブなデータとして扱っていません。これは録画ファイル出力時の録画メタデータファイルに含まれ、映像合成時に利用する事を想定しているためです。
    # ID: 1, Score: 0.3974, Content: 例えば 3 ノードのクラスターがある場合、 すでに接続しているクライアントがいるノードとは異なるノードにクライアントが接続した場合、Sora はその異なるノードにすでに接続しているクライアントの音声や映像、データをリレーします。
    # ID: 3, Score: 0.4412, Content: WebSocket は TCP ベースのため Head of Line Blocking が存在し、不安定な回線などでパケットが詰まってしまうことがあります。 DataChannel は WebSocket とは異なり、パケットを並列でやりとりできるため、不安定な回線などでもパケットが詰まることが少なくなります。 シグナリングを WebSocket 経由から DataChannel 経由へ切り替える機能を提供することでより安定した接続が維持できます。
    # --- Reranking ---
    # ID: 2, Score: 0.3005, Content: StartRecording API やセッションウェブフックの戻り値で指定できる録画メタデータについてはセンシティブなデータとして扱っていません。これは録画ファイル出力時の録画メタデータファイルに含まれ、映像合成時に利用する事を想定しているためです。
    # ID: 1, Score: 0.0002, Content: 例えば 3 ノードのクラスターがある場合、 すでに接続しているクライアントがいるノードとは異なるノードにクライアントが接続した場合、Sora はその異なるノードにすでに接続しているクライアントの音声や映像、データをリレーします。
    # ID: 3, Score: 0.0002, Content: WebSocket は TCP ベースのため Head of Line Blocking が存在し、不安定な回線などでパケットが詰まってしまうことがあります。 DataChannel は WebSocket とは異なり、パケットを並列でやりとりできるため、不安定な回線などでもパケットが詰まることが少なくなります。 シグナリングを WebSocket 経由から DataChannel 経由へ切り替える機能を提供することでより安定した接続が維持できます。

ベクトル検索と日本語全文検索をマージして Reranking を利用したハイブリッド検索を DuckDB でサクサクっと実現することができるので、是非試してみてください。