Coverage for api/utils/data_grabber.py: 15%
258 statements
« prev ^ index » next coverage.py v7.6.2, created at 2024-10-10 03:02 +0300
1import datetime
2import inspect
3import json
4from json import JSONDecodeError
5from typing import Any, Awaitable, Callable, Optional, cast
6from uuid import UUID
8import pymongo
9from bson import ObjectId
10from database.text_search.mongo_search import update_query_by_search
11from dbcc import MongoTableEngine
12from errors import log_error, log_warning
13from exceptions import NotAcceptableHTTPError, NotFoundHTTPError
14from fastapi import HTTPException, Query
15from handlers.authorization.assigned_filters import (
16 ensure_default_assigned,
17 ensure_default_clients,
18 ensure_default_clients_assignment,
19 make_filter_from_assignment,
20)
21from handlers.authorization.check_role import has_role
22from pydantic import BaseModel
23from pydantic_core import ValidationError
24from sotrans_models.models.orders.order import OrderDBModel
25from sotrans_models.models.responses import GenericGetListResponse
26from sotrans_models.models.roles import SotransRole
27from sotrans_models.models.users import SotransOIDCUserModel
28from starlette import status
29from utils.dt_utils import parse_datetime_filter
class BaseGetListQueryParams:
    """Query-string parameters shared by list endpoints.

    Every argument defaults to a ``fastapi.Query`` marker, so the class is
    meant to be injected by FastAPI. Constructing it directly without
    arguments leaves those marker objects in the attributes. String
    parameters are kept raw; ``MongoDataGrabber`` parses them later.

    :param skip: number of documents to skip; negative values are rejected
        by FastAPI validation instead of failing inside the database driver.
    :param limit: requested page size (``MongoDataGrabber.get_list`` caps it).
    :param sort: pymongo-style sort list, e.g. ``[("date", -1)]``.
    :param where: JSON object in pymongo ``find`` filter syntax.
    :param search: free-text search string.
    :param projection: JSON object in pymongo projection syntax.
    :param assignment: JSON list describing assignment filters.
    """

    def __init__(
        self,
        skip: int = Query(0, ge=0, description=""),
        limit: int = Query(200, description=""),
        sort: str | None = Query(None, description=""),
        where: str | None = Query(None, description=""),
        search: str | None = Query(None, description=""),
        projection: str | None = Query(None, description=""),
        assignment: str | None = Query(None, description=""),
    ):
        self.assignment = assignment
        self.where = where
        self.sort = sort
        self.search = search
        self.projection = projection
        self.skip = skip
        self.limit = limit
class BaseGetOneQueryParams:
    """Query-string parameters of single-entity endpoints.

    ``projection`` defaults to a ``fastapi.Query`` marker so that FastAPI can
    inject it. When constructing the class directly, pass ``projection``
    explicitly (``None`` for no projection); otherwise the marker object
    itself is stored as the projection.

    :param projection: JSON object in pymongo projection syntax, or ``None``.
    """

    def __init__(
        self,
        projection: str | None = Query(None, description=""),
    ):
        self.projection = projection
def get_field_model(sample: Any, k: str, model_out: type[BaseModel]) -> Any:
    """Return ``sample`` as ``model_out`` validates it at the dotted path ``k``.

    A nested payload is built from the dotted key (numeric segments become
    single-element lists), validated through ``model_out`` and dumped back.
    The leaf value is then dug out through nested dicts. Callers use the
    leaf's Python type to learn the declared type of the field behind ``k``.

    Example: ``k="a.b"`` builds ``{"a": {"b": sample}}``.

    :param sample: raw (JSON-decoded) value taken from a client query.
    :param k: dotted field path, e.g. ``"a.b"`` or ``"items.0.id"``.
    :param model_out: pydantic model describing the stored documents.
    :return: the validated value found by descending through dicts.
    :raises HTTPException: 400 if the payload does not validate or the path
        starts with a numeric segment.
    :raises IndexError: if the dump is empty (e.g. ``sample`` is ``None`` or
        ``k`` names no field of the model). Callers treat this as "nothing
        to convert".
    """
    error_str = "Неверные параметры запроса"
    # Build the payload from the innermost key outwards:
    # "items.0.id" -> {"id": s} -> [{"id": s}] -> {"items": [{"id": s}]}
    full_path = k.split(".")
    modeling: Any = {full_path.pop(): sample}
    while full_path:
        next_key = full_path.pop()
        modeling = [modeling] if next_key.isdigit() else {next_key: modeling}
    if not isinstance(modeling, dict):
        # A path starting with an index ("0.name") cannot name a model field;
        # without this check ``model(**modeling)`` fails with a TypeError.
        log_error(error_str)
        raise HTTPException(status.HTTP_400_BAD_REQUEST, error_str)
    model = cast(type[BaseModel], model_out)
    try:
        modeled = model(**modeling)
    except ValidationError:
        log_error(error_str)
        raise HTTPException(status.HTTP_400_BAD_REQUEST, error_str)
    # exclude_unset keeps only what was built above. Otherwise fields with
    # non-None defaults are dumped too and may be picked as the "first value".
    field_model = modeled.model_dump(exclude_none=True, exclude_unset=True)
    # Descend through nested dicts via their first value; an empty dict raises
    # IndexError on purpose (see docstring).
    while isinstance(fmv := list(field_model.values())[0], dict):
        field_model = fmv
    return fmv
class MongoDataGrabber:
    """Read-side helper that turns HTTP query parameters into MongoDB queries.

    Wraps one collection (``collection``) and one pydantic output model
    (``model_out``). Query-string values arrive as raw strings and are parsed
    by the ``parse_*`` static methods. ``where_conversion`` converts
    identifier/date operands in ``where`` to Python types, according to the
    field types declared on ``model_out``.

    ``settings`` is accepted and stored, but no method of this class reads
    it. Its intended layout, as originally documented::

        {
            'where': {
                'fields': {
                    'secret': 0,
                    'token': 0,  # allow filtering on every field but these
                },
                'operators': {
                    'in': 1,
                    'all': 1,  # allow only the listed operators
                },
            },
            'sort': {
                'default': [()],
                'max_fields': 2,
                'fields': {
                    'name': 1,
                    'date': 1,  # allow sorting only by these fields
                },
            },
            'projection': {
                'default': {},
                'policy': 'force',
            },
        }
    """

    # Upper bound on the page size served by ``get_list``.
    MAX_LIMIT: int = 250
    # Field-level comparison operators whose operands get type-converted.
    _COMPARE_OPS = ("$gte", "$gt", "$eq", "$lt", "$lte", "$ne")
    # Membership operators whose list operands get type-converted.
    _MEMBERSHIP_OPS = ("$in", "$nin")
    # Logical operators whose sub-filters are converted recursively.
    _LOGICAL_OPS = ("$or", "$and", "$nor")
    # Field types that JSON cannot carry and therefore need conversion.
    _CONVERTIBLE_TYPES = (ObjectId, UUID, datetime.datetime)

    model_out: type
    collection: MongoTableEngine
    settings: dict | None

    def __init__(
        self,
        model_out: type,
        collection: MongoTableEngine,
        settings: dict | None = None,
    ):
        self.model_out = model_out
        self.collection = collection
        self.settings = settings

    @staticmethod
    def convert_values(
        element: Any, annotation: type | Callable, in_op: bool
    ) -> list | Any:
        """Convert a JSON-decoded filter operand to a field's Python type.

        :param element: the operand. With ``in_op`` it is a dict holding
            ``$in`` and/or ``$nin`` lists; it is converted in place and
            returned.
        :param annotation: target type or converter callable.
            ``datetime.datetime`` is routed through ``parse_datetime_filter``.
        :param in_op: whether ``element`` is a membership-operator dict.
        :return: the converted operand.
        """
        if annotation is datetime.datetime:
            annotation = parse_datetime_filter
        if not in_op:
            return annotation(element)
        for op in MongoDataGrabber._MEMBERSHIP_OPS:
            if op in element:
                # None stays None: it is meaningful inside $in/$nin, and e.g.
                # ObjectId(None) would silently generate a brand-new id.
                element[op] = [
                    v if v is None else annotation(v) for v in element[op]
                ]
        return element

    @staticmethod
    def where_conversion(pattern: dict | None, model_out: type[BaseModel]):
        """Convert, in place, filter operands to the types declared on the model.

        JSON cannot carry ObjectId/UUID/datetime values. For each field of
        ``pattern``, a sample operand is validated through ``model_out`` (see
        ``get_field_model``). If the field turns out to be one of those types,
        the operands are converted.

        Supported shapes per field:
        - a plain value;
        - ``{"$in"/"$nin": [...]}``;
        - comparison operators (``$gt``, ``$gte``, ``$lt``, ``$lte``,
          ``$eq``, ``$ne``).

        ``$or``/``$and``/``$nor`` sub-filters are processed recursively.

        :raises HTTPException: 400 if a ``$in``/``$nin`` operand is not a
            list, or a sample value does not validate against the model.
        """
        if not pattern:
            return
        cls = MongoDataGrabber
        for k in pattern:
            value = pattern[k]
            if k in cls._LOGICAL_OPS:
                if isinstance(value, list):
                    for sub_filter in value:
                        if isinstance(sub_filter, dict):
                            cls.where_conversion(sub_filter, model_out)
                continue

            is_membership = False
            compared = False
            sample = value
            if isinstance(value, dict) and any(
                op in value for op in cls._MEMBERSHIP_OPS
            ):
                present_ops = [op for op in cls._MEMBERSHIP_OPS if op in value]
                for op in present_ops:
                    if not isinstance(value[op], list):
                        raise HTTPException(
                            status.HTTP_400_BAD_REQUEST,
                            f"Неверный параметр {op}.",
                        )
                # The first non-None element stands in for the whole list.
                candidates = [
                    v for op in present_ops for v in value[op] if v is not None
                ]
                if not candidates:
                    continue
                is_membership = True
                sample = candidates[0]
            elif isinstance(value, dict) and any(
                op in value for op in cls._COMPARE_OPS
            ):
                candidates = [
                    value[op]
                    for op in value
                    if op in cls._COMPARE_OPS and value[op] is not None
                ]
                if not candidates:
                    continue
                compared = True
                sample = candidates[0]

            try:
                field_value = get_field_model(sample, k, model_out)
            except IndexError:
                # nothing to learn about this key's type; leave it untouched
                continue
            field_type = type(field_value)
            if field_type not in cls._CONVERTIBLE_TYPES:
                continue
            if compared:
                # only comparison operands; e.g. {"$exists": true} stays as-is
                for op in value:
                    if op in cls._COMPARE_OPS and value[op] is not None:
                        value[op] = cls.convert_values(
                            value[op], field_type, False
                        )
            else:
                pattern[k] = cls.convert_values(
                    value, field_type, is_membership
                )

    async def get_list(
        self,
        params: BaseGetListQueryParams,
        user: SotransOIDCUserModel | None = None,
        *,
        clients_assignment_default: bool = False,
    ):
        """Return one page of entities that match the request parameters.

        The filter is built in three steps:
        1. The client ``where`` filter is parsed and type-converted.
        2. The full-text ``search`` filter is merged in.
        3. The assignment filter built for ``user`` is merged in; this step
           is skipped when there is no user or the user has the drivers-bot
           role.

        Documents that fail ``model_out`` validation (or raise anything
        else) are logged and left out of ``items``. ``total`` still counts
        them, because it counts the raw filter matches.

        :param params: list query parameters.
        :param user: requesting user; ``None`` disables the assignment filter.
        :param clients_assignment_default: forwarded to ``parse_assignment``.
        :return: ``GenericGetListResponse[model_out]`` with ``items``/``total``.
        :raises HTTPException: 400 on malformed query parameters.
        """
        where = self.parse_where(params.where)
        self.where_conversion(where, self.model_out)
        search_query: dict = {}
        if params.search:
            await update_query_by_search(
                self.collection, params.search, search_query
            )
        if where:
            where.update(search_query)
        else:
            where = search_query
        if user and not has_role(user, SotransRole.drivers_bot):
            # Server-side restrictions are merged last, so they override any
            # same-named keys supplied by the client.
            where.update(
                await self.parse_assignment(
                    params.assignment,
                    user,
                    clients_assignment_default=clients_assignment_default,
                )
            )
        sort = self.parse_sort(params.sort) or [("_id", pymongo.DESCENDING)]
        projection = self.parse_projection(params.projection)
        # A limit of 0 means "no limit" to a MongoDB cursor, and negative
        # values also slipped past a plain "<= cap" check, so both fall back
        # to the cap.
        limit = (
            params.limit
            if 0 < params.limit <= self.MAX_LIMIT
            else self.MAX_LIMIT
        )
        skip = max(params.skip, 0)
        output = []
        async for raw in self.collection.find_batch_raw(
            pattern=where,
            skip=skip,
            sort=sort,
            projection=projection,
            limit=limit,
        ):
            try:
                entity = self.model_out(**raw)
            except ValidationError:
                log_error(
                    f"collection: {self.collection.collection_name}, id: {raw.get('id', raw.get('_id'))}"
                )
                continue
            except Exception:
                # best effort: one broken document must not fail the page
                log_error()
                continue
            # TODO: Make custom BaseModel and implement parameters config_exclude_none and config_exclude_unset
            # and overwrite model_dump function so that implemented parameters are more important than function.
            # That way using fastapi there will be no double conversion between dict and model
            output.append(entity.model_dump(format_ids=False))
        total = await self.collection.count(where)
        return GenericGetListResponse[self.model_out](
            items=output, total=total
        )

    async def get_one(
        self,
        id: str | ObjectId,
        params: BaseGetOneQueryParams | None = None,
    ):
        """Fetch one entity by id and return it dumped as a dict.

        :param id: entity id as ``ObjectId`` or its string form.
        :param params: single-entity query parameters; ``None`` means no
            projection.
        :return: ``model_out(**document).model_dump(format_ids=False)``.
        :raises HTTPException: 404 if ``id`` is not a valid ObjectId or no
            document is found; 400 on a malformed projection.
        """
        if params is None:
            # Never rely on the class defaults here: outside FastAPI they are
            # ``Query`` marker objects, not ``None``.
            params = BaseGetOneQueryParams(projection=None)
        not_found_detail = f"Сущность с ID {id} не найдена."
        if not ObjectId.is_valid(id):
            raise HTTPException(status_code=404, detail=not_found_detail)
        data = await self.collection.find_single(
            "id",
            ObjectId(id),
            projection=self.parse_projection(params.projection),
        )
        if not data:
            raise HTTPException(status_code=404, detail=not_found_detail)
        return self.model_out(**data).model_dump(format_ids=False)

    async def get_one_by_id_with_pattern(
        self,
        id: ObjectId | str,
        params: BaseGetOneQueryParams | None = None,
        pattern: dict | None = None,
        processors: list[Callable[[BaseModel], Awaitable[None] | None]]
        | None = None,
    ) -> BaseModel:
        """Fetch one entity by id plus extra filter conditions, as a model.

        :param id: entity id as ``ObjectId`` or its string form.
        :param params: single-entity query parameters; ``None`` means no
            projection.
        :param pattern: extra filter conditions. It is merged with ``|`` after
            the id condition, so a key in ``pattern`` overrides it.
        :param processors: callables applied to the entity in order. Their
            results are awaited when awaitable.
        :return: the validated ``model_out`` instance.
        :raises NotFoundHTTPError: invalid id or no matching document.
        :raises NotAcceptableHTTPError: the document fails validation.
        :raises HTTPException: 400 on a malformed projection.
        """
        if params is None:
            params = BaseGetOneQueryParams(projection=None)
        if not ObjectId.is_valid(id):
            raise NotFoundHTTPError
        # Orders are looked up by their own "id" field, other models by "_id".
        id_key = "id" if self.model_out is OrderDBModel else "_id"
        found = await self.collection.collection.find_one(
            {id_key: ObjectId(id)} | (pattern or {}),
            projection=self.parse_projection(params.projection),
        )
        if not found:
            raise NotFoundHTTPError
        try:
            entity = self.model_out(**found)
        except ValidationError:
            log_error(f"{found} missed schema.")
            raise NotAcceptableHTTPError(
                "Сущность не соответствует схеме."
            )
        for processor in processors or []:
            result = processor(entity)
            # covers async functions as well as partials/callables returning
            # coroutines, which iscoroutinefunction() would miss
            if inspect.isawaitable(result):
                await result
        return entity

    @staticmethod
    def parse_where(query_input: str | None) -> Optional[dict]:
        """Parse the ``where`` query parameter into a pymongo filter dict.

        :return: the decoded JSON object, or ``None`` for empty input.
        :raises HTTPException: 400 if the input is not a JSON object.
        """
        if not query_input:
            return None
        text = (
            f"Неверный where запрос: {query_input}. "
            "Синтаксис должен совпадать с параметром filter функции find в модуле pymongo."
        )
        try:
            parsed = json.loads(query_input)
        except JSONDecodeError:
            log_warning(text)
            raise HTTPException(status_code=400, detail=text)
        if not isinstance(parsed, dict):
            # lists/scalars would break every caller that merges into it
            log_warning(text)
            raise HTTPException(status_code=400, detail=text)
        return parsed

    @staticmethod
    async def parse_assignment(
        query_input: str | None,
        user: SotransOIDCUserModel,
        *,
        clients_assignment_default: bool = False,
    ) -> dict[str, Any]:
        """Build the assignment filter for ``user`` from the ``assignment`` parameter.

        ``query_input`` must be a JSON list (or empty).

        Users who are neither company directors nor carrier-side users get
        default conditions added through one of the ``ensure_default_*``
        helpers. Which helper is used depends on
        ``clients_assignment_default`` and the company-manager role; what
        each helper adds is defined outside this file.

        :return: filter produced by ``make_filter_from_assignment``.
        :raises HTTPException: 400 on malformed input or when the helper
            raises ``KeyError``.
        """
        text = f"Неверный формат фильтра assignment: {query_input}"
        query: Any = []
        if query_input:
            try:
                query = json.loads(query_input)
            except JSONDecodeError:
                log_warning(text)
                raise HTTPException(status.HTTP_400_BAD_REQUEST, text)
            if not isinstance(query, list):
                log_warning(text)
                raise HTTPException(status.HTTP_400_BAD_REQUEST, text)
        is_carrier = (
            has_role(user, SotransRole.carrier_logistician)
            or has_role(user, SotransRole.carrier_director)
        ) and not has_role(user, SotransRole.company_logistician)
        if not (has_role(user, SotransRole.company_director) or is_carrier):
            if clients_assignment_default:
                ensure_default_clients(query)
            elif has_role(user, SotransRole.company_manager):
                ensure_default_clients_assignment(query)
            else:
                ensure_default_assigned(query)
        try:
            return await make_filter_from_assignment(query, user)
        except KeyError:
            log_warning(text)
            raise HTTPException(status.HTTP_400_BAD_REQUEST, text)

    @staticmethod
    def parse_projection(query_input: str | None) -> Optional[dict]:
        """Parse the ``projection`` query parameter into a pymongo projection.

        :return: the decoded JSON object, or ``None`` for empty input.
        :raises HTTPException: 400 if the input is not a JSON object or mixes
            inclusion (1) and exclusion (0) of fields other than ``_id``.
        """
        if not query_input:
            return None
        text = (
            f"Неверный запрос projection: {query_input}. "
            "Синтаксис должен совпадать с параметром projection функции find модуля pymongo."
        )
        try:
            projection = json.loads(query_input)
        except JSONDecodeError:
            log_warning(text)
            raise HTTPException(status_code=400, detail=text)
        if not isinstance(projection, dict):
            log_warning(text)
            raise HTTPException(status_code=400, detail=text)
        # MongoDB lets `_id` be excluded inside an inclusion projection, so it
        # does not count towards the include/exclude mix.
        flags = [v for key, v in projection.items() if key != "_id"]
        if 1 in flags and 0 in flags:
            raise HTTPException(
                status_code=400,
                detail="Запрос может быть либо включающим, либо исключающим",
            )
        return projection

    @staticmethod
    def parse_sort(query_input: str | None) -> Optional[list]:
        """Parse ``[("field", 1), ("other", -1)]``-style input into sort pairs.

        Quotes around field names are optional, and the outer brackets are
        dropped unconditionally.

        :return: list of ``(field, direction)`` tuples, or ``None`` if empty.
        :raises HTTPException: 400 on a malformed pair, an empty field name,
            or a direction other than 1/-1.
        """
        if not query_input:
            return None
        text = (
            f"Неверный параметр sort: {query_input}. "
            "Синтаксис должен совпадать с sort функции find модуля pymongo."
        )
        output = []
        for chunk in query_input[1:-1].split("),"):
            if not chunk:
                continue
            parts = chunk.strip().replace("(", "").replace(")", "").split(",")
            try:
                field = parts[0].strip().replace('"', "").replace("'", "")
                direction = int(parts[1].strip())
            except (ValueError, IndexError):
                log_warning(text)
                raise HTTPException(status_code=400, detail=text)
            if not field or direction not in (
                pymongo.ASCENDING,
                pymongo.DESCENDING,
            ):
                log_warning(text)
                raise HTTPException(status_code=400, detail=text)
            output.append((field, direction))
        return output or None
def _is_allowed_value(value: Any, allowed_values: Any) -> bool:
    """Membership test that treats unhashable operands as not allowed."""
    try:
        return value in allowed_values
    except TypeError:
        # e.g. an operator dict such as {"$ne": ...} tested against a set
        return False


def adjust_search_query(
    initial_query: str,
    filtered_query_key: str,
    allowed_values: set | None = None,
) -> str:
    """
    Restrict a client ``where`` query so that one key only takes allowed values.

    :param initial_query: client's JSON ``where`` string (may be empty).
    :param filtered_query_key: the key whose values must be restricted.
    :param allowed_values: values the key may take (JSON-serializable).
        ``None`` means "no restriction": any client filter on the key is
        dropped, so the user gets unfiltered data for it.
    :return: the adjusted query serialized back to JSON.
    :raises HTTPException: 404 when the client asks for a value (or operator
        expression) outside ``allowed_values``; 400 when the client's ``$in``
        operand is not a list or the query is malformed.

    How the client's filter on the key is handled:
    - absent: ``{key: {"$in": allowed}}`` is added;
    - ``$in`` list: it is narrowed to the allowed values (possibly to an
      empty list, which matches nothing);
    - any other value: it must itself be an allowed value.
    """
    search_query = MongoDataGrabber.parse_where(initial_query) or {}
    if filtered_query_key not in search_query:
        if allowed_values is not None:
            search_query[filtered_query_key] = {"$in": list(allowed_values)}
        return json.dumps(search_query)
    if allowed_values is None:
        del search_query[filtered_query_key]
        return json.dumps(search_query)
    requested = search_query[filtered_query_key]
    if isinstance(requested, dict) and "$in" in requested:
        if not isinstance(requested["$in"], list):
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, "Неверный параметр $in."
            )
        requested["$in"] = [
            value
            for value in requested["$in"]
            if _is_allowed_value(value, allowed_values)
        ]
    elif not _is_allowed_value(requested, allowed_values):
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            "Нет сущностей подходящих по параметрам запроса.",
        )
    return json.dumps(search_query)
def update_search_query(user_query: str, server_query: dict) -> str:
    """Merge server-enforced conditions into a client ``where`` string.

    The client may not constrain any key that the server sets itself. This
    covers both the top-level keys of ``server_query`` and the keys used in
    the branches of its ``$or``; such a collision is rejected.

    :param user_query: client's JSON ``where`` string (may be empty).
    :param server_query: conditions that must be applied. They are added as
        top-level keys of the merged filter.
    :return: the merged filter serialized back to JSON.
    :raises HTTPException: 400 on a colliding key or a malformed
        ``user_query``.
    """
    merged = MongoDataGrabber.parse_where(user_query) or {}
    for server_key, server_value in server_query.items():
        if server_key in merged:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, f"Внутренний параметр {server_key}"
            )
        if server_key != "$or":
            continue
        for branch in server_value:
            for branch_key in branch:
                if branch_key in merged:
                    raise HTTPException(
                        status.HTTP_400_BAD_REQUEST,
                        f"Внутренний параметр {branch_key}",
                    )
    merged.update(server_query)
    return json.dumps(merged)