Source code for restapi.models.fields

import re
from collections.abc import Mapping
from typing import Any, Optional, Union

import orjson
from marshmallow import ValidationError, types
from marshmallow.utils import missing as missing_
from webargs import fields

from restapi.config import TESTING
from restapi.utilities.logs import log

####################################################################################
#    1) Override some types with custom implementation to extend functionalities
####################################################################################


[docs] class String(fields.String): def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: Any, ) -> Any: value = super()._deserialize(value, attr, data, **kwargs) if value and isinstance(value, str): value = value.strip() return value
[docs] class List(fields.List):
[docs] def __init__( self, cls_or_instance: Union[fields.Field, type], *, unique: bool = False, min_items: int = 0, **kwargs: Any, ) -> None: super().__init__(cls_or_instance, **kwargs) self.unique = unique self.min_items = min_items
def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: Any, ) -> Any: # this is the case when requests (or pytest) send some json-dumped lists # for example for a multi-value select if isinstance(value, str): try: value = orjson.loads(value) except Exception: pass value = super()._deserialize(value, attr, data, **kwargs) if not isinstance(value, list): # pragma: no cover raise ValidationError("Invalid type") if self.unique: value = list(set(value)) if len(value) < self.min_items: raise ValidationError( f"Expected at least {self.min_items} items, received {len(value)}" ) return value
[docs] class Nested(fields.Nested):
[docs] def __init__( self, # nested: Union[SchemaABC, type, str, Callable[[], SchemaABC]], # the above type is from marshmallow, but it fails with dict[str, Any] (imc) nested: Any, *, default: Any = missing_, only: Optional[types.StrSequenceOrSet] = None, exclude: types.StrSequenceOrSet = (), many: bool = False, unknown: Optional[str] = None, **kwargs: Any, ): super().__init__( nested, default=default, only=only, exclude=exclude, many=many, unknown=unknown, **kwargs, )
def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], partial: Optional[Union[bool, tuple[str]]] = None, **kwargs: Any, ) -> Any: # this is the case when requests (or pytest) send some json-dumped object if isinstance(value, str): try: value = orjson.loads(value) except Exception as e: # pragma: no cover log.warning(e) # This is because Nested is not typed on marshmallow return super()._deserialize(value, attr, data, **kwargs) # type: ignore
# DelimitedList is child of List as defined in: # https://github.com/marshmallow-code/webargs/blob/dev/src/webargs/fields.py
[docs] class DelimitedList(fields.DelimitedList, List):
[docs] def __init__( self, cls_or_instance: Union[fields.Field, type], *, delimiter: Optional[str] = None, unique: bool = False, **kwargs: Any, ) -> None: super().__init__(cls_or_instance, delimiter=delimiter, **kwargs) # Note: Can't use self.unique otherwise the elements will be silently cleaned # by the custom List deserializer # self.unique = unique self.no_duplicates = unique
def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: Any, ) -> Any: if not value: return value # This is because List is not typed on marshmallow values = super()._deserialize(value, attr, data, **kwargs) # type: ignore if self.no_duplicates and len(values) != len(set(values)): raise ValidationError("Input list contains duplicates") return values
#################################################################################### # 2) Add some additional custom fields #################################################################################### # Should be extended to automatically include the choices as validation field to be # Converted as a select on frontend
[docs] class Neo4jChoice(fields.Field): """Field that serializes from a neo4j choice""" # choice_model is the same used in neo4j model as choices=
[docs] def __init__(self, choices_model: Any, **kwargs: Any) -> None: super().__init__(**kwargs) if isinstance(choices_model, dict): self.choices_dict = choices_model else: # convert the tuple of tuple into as a dictionary for convenience self.choices_dict = {} for k, v in choices_model: self.choices_dict.setdefault(k, v)
def _serialize( self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any ) -> Any: return { "key": value, # the value correspondance from choices_dict or value as default "description": self.choices_dict.get(value, value), } def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: Any, ) -> Any: return value
[docs] class Neo4jRelationshipToMany(Nested): # value: StructuredRel # Signature of "_serialize" incompatible with supertype "Nested" # This is because Nested is not typed on marshmallow def _serialize( # type: ignore self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any ) -> Any: self.many = True # This is because Nested is not typed on marshmallow return super()._serialize(value.all(), attr, obj, **kwargs) # type: ignore
[docs] class Neo4jRelationshipToSingle(Nested): # value: StructuredRel # Signature of "_serialize" incompatible with supertype "Nested" # This is because Nested is not typed on marshmallow def _serialize( # type: ignore self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any, ) -> Any: self.many = False self.schema.many = False # This is because Nested is not typed on marshmallow return super()._serialize(value.single(), attr, obj, **kwargs) # type: ignore
[docs] class Neo4jRelationshipToCount(fields.Int): # value: StructuredRel def _serialize( self, value: Any, attr: Optional[str], obj: Any, **kwargs: Any ) -> Any: return self._format_num(len(value))
[docs] class TOTP(String): def _deserialize( self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: Any, ) -> Any: value = super()._deserialize(value, attr, data, **kwargs) if not re.match(r"^[0-9]{6}$", value): if TESTING: log.error("Invalid TOTP format: {}", value) raise ValidationError("Invalid TOTP format") return value
#################################################################################### # 3) Create some aliases #################################################################################### Field = fields.Field Boolean = fields.Boolean DateTime = fields.DateTime Dict = fields.Dict Number = fields.Number AwareDateTime = fields.AwareDateTime Date = fields.Date Decimal = fields.Decimal Email = fields.Email Integer = fields.Integer Float = fields.Float NaiveDateTime = fields.NaiveDateTime Url = fields.Url UUID = fields.UUID # Aliases, as defined in # https://github.com/marshmallow-code/marshmallow/blob/dev/src/marshmallow/fields.py URL = Url Str = String Bool = Boolean Int = Integer