Coverage for api/services/opensearch_client.py: 46%
48 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-10 03:02 +0300
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-10 03:02 +0300
1import functools
3import config
4import indexes
5import mongodb
6from bson import ObjectId
7from dbcc import MongoTableEngine
8from errors import log_error
9from fastapi import HTTPException
10from opensearchpy import AsyncOpenSearch, NotFoundError
11from starlette import status
12from utils.helper import validate_search_q
# Module-level async OpenSearch client, configured entirely from
# ``config.OpenSearchConfig`` (host, port and basic-auth credentials).
# NOTE(review): ``use_ssl=False`` together with ``http_auth`` means the
# connection is not TLS-protected -- confirm this is intended outside of
# local/dev deployments.
client = AsyncOpenSearch(
    hosts=[
        {"host": config.OpenSearchConfig.HOST, "port": config.OpenSearchConfig.PORT},
    ],
    http_auth=(config.OpenSearchConfig.LOGIN, config.OpenSearchConfig.PASSWORD),
    use_ssl=False,
    ssl_show_warn=False,
)
class OpenSearchAdapter:
    """Search facade over an ``AsyncOpenSearch`` client.

    Both search methods return the ``_id`` values of the matching documents
    converted to ``ObjectId``.
    """

    def __init__(self, os_client: AsyncOpenSearch):
        """Remember the client that every search request is sent through."""
        self.client = os_client

    async def search_fts(self, index: str, search_q: str) -> list[ObjectId]:
        """Run a full-text ``phrase_prefix`` search over all fields of *index*.

        Args:
            index: Name of the OpenSearch index to query.
            search_q: Raw search text; it is passed through
                ``validate_search_q`` before being sent.

        Returns:
            The ``_id`` of every hit as an ``ObjectId``, or an empty list
            when the client raises ``NotFoundError``.
        """
        phrase = validate_search_q(search_q)
        request_body = {
            "query": {
                "multi_match": {
                    "query": phrase,
                    "fields": ["*"],
                    "type": "phrase_prefix",
                }
            }
        }
        try:
            response = await self.client.search(body=request_body, index=index)
        except NotFoundError:
            # A "not found" answer is deliberately treated as "no matches"
            # (error logging is currently switched off here).
            return []
        return [ObjectId(hit["_id"]) for hit in response["hits"]["hits"]]

    async def vector_search(
        self, index: str, search_query: str, k: int = 1
    ) -> list[ObjectId]:
        """Return no matches: vector search is currently stubbed out.

        All arguments are accepted for interface compatibility and ignored.
        """
        # The disabled implementation worked as follows:
        #   1. fetch the model id via get_opensearch_vector_model_id() and
        #      raise HTTPException(503) when there is none;
        #   2. pass the query through validate_search_q();
        #   3. send a "neural" query against the "embedding" field with
        #      "query_text", "model_id" and "k", excluding "embedding" from
        #      "_source";
        #   4. map the hits' "_id" values to ObjectId, returning [] on
        #      NotFoundError.
        return []
class MockAdapter:
    """Stand-in search adapter backed directly by MongoDB.

    Both search methods ignore the query text and return the ``_id`` of
    every document fetched from the collection named by *index*.
    """

    @staticmethod
    def _collection_name(index: str) -> str:
        """Extract the collection name from a ``<database>.<collection>`` index.

        Only the first dot separates the database from the collection, so
        collection names that themselves contain dots (``fs.files``) are
        kept intact.

        Raises:
            IndexError: If *index* contains no dot at all.
        """
        return index.split(".", 1)[1]

    async def _all_ids(self, index: str) -> list[ObjectId]:
        """Return the ``_id`` of every document in the collection behind *index*."""
        collection = mongodb.db[self._collection_name(index)]
        documents = await collection.find_batch({}, projection={"_id": 1})
        return [doc["_id"] for doc in documents]

    async def search_fts(self, index: str, search_q: str) -> list[ObjectId]:
        """Pretend full-text search: every document matches.

        Args:
            index: Index name of the form ``<database>.<collection>``.
            search_q: Ignored.
        """
        return await self._all_ids(index)

    async def vector_search(
        self, index: str, search_query: str, k: int = 1
    ) -> list[ObjectId]:
        """Pretend vector search: every document matches.

        Args:
            index: Index name of the form ``<database>.<collection>``.
            search_query: Ignored.
            k: Ignored; the result is not limited to ``k`` entries.
        """
        return await self._all_ids(index)
# Adapter used by the rest of the module: the MongoDB-backed mock when
# ``config.TESTING`` is truthy, otherwise the real OpenSearch adapter.
opensearch_adapter: MockAdapter | OpenSearchAdapter = (
    MockAdapter() if config.TESTING else OpenSearchAdapter(client)
)
async def get_ids_query(
    collection: MongoTableEngine, search_q: str, k: int = 10
) -> dict:
    """Translate a search string into an ``_id`` filter for *collection*.

    Args:
        collection: Table engine whose ``collection_name`` identifies the
            index, as ``<MONGO_DB_NAME>.<collection_name>``.
        search_q: Search text; when it is falsy no search is performed.
        k: Forwarded to the adapter's ``vector_search``.

    Returns:
        ``{}`` when *search_q* is falsy, otherwise
        ``{"_id": {"$in": [...]}}`` holding the ids the adapter returned.
    """
    if search_q:
        index_name = f"{config.MONGO_DB_NAME}.{collection.collection_name}"
        matched_ids = await opensearch_adapter.vector_search(
            index_name, search_q, k=k
        )
        return {"_id": {"$in": matched_ids}}
    return {}
@functools.cache
def get_opensearch_vector_model_id() -> str | None:
    """Look up the vector model id stored in the ``os_index`` collection.

    Returns:
        The ``model_id`` field of the first ``os_index`` document, or
        ``None`` when there is no document or it has no such field.

    NOTE(review): the result is memoised by ``functools.cache``, and that
    includes ``None`` -- a model registered after the first call is not
    picked up until the cache is cleared. Confirm this is intended.
    """
    index_doc = indexes.db["os_index"].find_one({})
    return index_doc.get("model_id") if index_doc is not None else None