Coverage for api/services/opensearch_client.py: 46%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.2, created at 2024-10-10 03:02 +0300

1import functools 

2 

3import config 

4import indexes 

5import mongodb 

6from bson import ObjectId 

7from dbcc import MongoTableEngine 

8from errors import log_error 

9from fastapi import HTTPException 

10from opensearchpy import AsyncOpenSearch, NotFoundError 

11from starlette import status 

12from utils.helper import validate_search_q 

13 

# Module-wide async OpenSearch client, configured from config.OpenSearchConfig.
# NOTE(review): use_ssl=False combined with http_auth means the credentials are
# sent without TLS -- confirm this is acceptable for the target deployment.
client = AsyncOpenSearch(
    hosts=[
        dict(
            host=config.OpenSearchConfig.HOST,
            port=config.OpenSearchConfig.PORT,
        ),
    ],
    http_auth=(config.OpenSearchConfig.LOGIN, config.OpenSearchConfig.PASSWORD),
    use_ssl=False,
    ssl_show_warn=False,
)

28 

29 

class OpenSearchAdapter:
    """Search adapter that delegates to an ``AsyncOpenSearch`` client."""

    def __init__(self, os_client: AsyncOpenSearch):
        """Store the client used for all subsequent search calls."""
        self.client = os_client

    async def search_fts(self, index: str, search_q: str) -> list[ObjectId]:
        """Run a full-text ``phrase_prefix`` search over every field of *index*.

        The query text is first passed through ``validate_search_q``. No
        ``size`` is sent with the request, so the number of hits is whatever
        the server returns by default.

        Returns:
            The ``_id`` of each hit converted to ``ObjectId``, or an empty
            list when the client raises ``NotFoundError``.
        """
        phrase = validate_search_q(search_q)
        request_body = {
            "query": {
                "multi_match": {
                    "query": phrase,
                    "fields": ["*"],
                    "type": "phrase_prefix",
                }
            }
        }
        try:
            response = await self.client.search(body=request_body, index=index)
        except NotFoundError:
            # log_error()
            return []

        return [ObjectId(hit["_id"]) for hit in response["hits"]["hits"]]

    async def vector_search(
        self, index: str, search_query: str, k: int = 1
    ) -> list[ObjectId]:
        """Vector (neural) search -- currently disabled; always returns ``[]``.

        All arguments are accepted but ignored. The previous implementation
        is kept below, commented out, for reference.
        """
        return []
        # Disabled implementation:
        # vector_model_id = get_opensearch_vector_model_id()
        # if vector_model_id is None:
        #     raise HTTPException(
        #         status.HTTP_503_SERVICE_UNAVAILABLE,
        #         "Нет модели для векторного поиска",
        #         # i.e. "No model available for vector search"
        #     )
        # search_query = validate_search_q(search_query)
        # try:
        #     response = await self.client.search(
        #         body={
        #             # "size": k,
        #             "_source": {"excludes": ["embedding"]},
        #             "query": {
        #                 "neural": {
        #                     "embedding": {
        #                         "query_text": search_query,
        #                         "model_id": vector_model_id,
        #                         "k": k,
        #                     }
        #                 }
        #             },
        #         },
        #         index=index,
        #     )
        #     hits = response["hits"]["hits"]
        #     return [ObjectId(e["_id"]) for e in hits]
        #
        # except NotFoundError:
        #     # log_error(f"search_q: [{search_query}], index: [{index}]")
        #     return []

90 

91 

class MockAdapter:
    """Stand-in for ``OpenSearchAdapter`` that answers from MongoDB.

    Both search methods ignore the query text (and ``k``) and return the
    ``_id`` of every document in the collection named by ``index``.
    """

    @staticmethod
    def _collection_name(index: str) -> str:
        """Return the collection part of a ``"<database>.<collection>"`` index.

        Only the first dot is treated as the separator, so a collection name
        that itself contains dots (e.g. ``"system.profile"``) is kept whole
        instead of being cut at its first dot.

        Raises:
            IndexError: if *index* contains no dot at all.
        """
        return index.split(".", 1)[1]

    async def _all_ids(self, index: str) -> list[ObjectId]:
        """Fetch the ``_id`` of every document in the collection for *index*."""
        col_name = self._collection_name(index)
        docs = await mongodb.db[col_name].find_batch({}, projection={"_id": 1})
        return [doc["_id"] for doc in docs]

    async def search_fts(self, index: str, search_q: str) -> list[ObjectId]:
        """Return every document id of the collection; *search_q* is ignored."""
        return await self._all_ids(index)

    async def vector_search(
        self, index: str, search_query: str, k: int = 1
    ) -> list[ObjectId]:
        """Return every document id of the collection.

        *search_query* and *k* are ignored.
        """
        return await self._all_ids(index)

104 

105 

# Choose the search backend once, at import time: the MongoDB-backed mock when
# config.TESTING is truthy, otherwise the real OpenSearch adapter.
opensearch_adapter: MockAdapter | OpenSearchAdapter = (
    MockAdapter() if config.TESTING else OpenSearchAdapter(client)
)

110 

111 

async def get_ids_query(
    collection: MongoTableEngine, search_q: str, k: int = 10
):
    """Build a MongoDB ``_id`` filter from a vector search over *collection*.

    The index searched is ``"<config.MONGO_DB_NAME>.<collection_name>"``.

    Returns:
        ``{}`` when *search_q* is falsy (no filtering); otherwise
        ``{"_id": {"$in": [...]}}`` with the ids returned by
        ``opensearch_adapter.vector_search``. An empty id list therefore
        produces a filter that matches no documents.
    """
    if search_q:
        index_name = f"{config.MONGO_DB_NAME}.{collection.collection_name}"
        matched_ids = await opensearch_adapter.vector_search(
            index_name, search_q, k=k
        )
        return {"_id": {"$in": matched_ids}}
    return {}

121 

122 

@functools.cache
def get_opensearch_vector_model_id() -> str | None:
    """Look up the vector model id stored in the ``os_index`` collection.

    Reads a single document from ``indexes.db["os_index"]`` and returns its
    ``model_id`` value, or ``None`` when there is no document or the key is
    absent.

    The result is memoized by ``functools.cache`` -- including a ``None``
    result -- so the lookup runs once per process unless ``cache_clear()``
    is called on this function.
    """
    record = indexes.db["os_index"].find_one({})
    return None if record is None else record.get("model_id")

128 return os_index.get("model_id")