# Source code for backend.api.extension

import enum

from flask import Blueprint, current_app, make_response
from flask.json import dumps, JSONEncoder as BaseJSONEncoder
from flask.views import MethodViewType
from flask_restful import Api as BaseApi
from flask_sqlalchemy.model import camel_to_snake_case, Model
from marshmallow import MarshalResult
from werkzeug.wrappers import Response

from backend.extensions import db
from backend.utils import was_decorated_without_parenthesis

from .constants import CREATE, DELETE, GET, LIST, PATCH, PUT
from .model_resource import ModelResource
from .utils import get_last_param_name


def _get_model_resource_args(args):
    """Split the positional arguments given to ``Api.model_resource``.

    Two call signatures are supported::

        model_resource(model, *urls)
        model_resource(blueprint, model, *urls)

    :param args: the positional ``*args`` tuple passed to ``model_resource``
    :return: a ``(blueprint_or_None, model, urls)`` triple, where ``urls`` is
        the tuple of remaining url rule arguments
    :raises NotImplementedError: if the model argument is missing or is not a
        ``db.Model`` subclass, or if no url arguments were given
    """
    bp = args[0] if args and isinstance(args[0], Blueprint) else None
    # the model follows the (optional) blueprint; everything after it is urls
    offset = 0 if bp is None else 1
    model = args[offset] if len(args) > offset else None
    urls = args[offset + 1:]

    # Check ``isinstance(model, type)`` first: ``issubclass`` raises a bare
    # TypeError for non-class arguments (e.g. a url string passed where the
    # model was expected), which would hide the intended error message.
    if not isinstance(model, type) or not issubclass(model, db.Model):
        raise NotImplementedError(
            f"The {'second' if bp else 'first'} argument to Api.model_resource"
            ' must be a database model class')
    if not urls:
        raise NotImplementedError(
            'Api.model_resource requires at least one url argument.')
    return bp, model, urls


class Api(BaseApi):
    """Extends :class:`flask_restful.Api` to support integration with
    Flask-Marshmallow serializers, along with a few other minor enhancements:

    - can register individual view functions ala blueprints, via @api.route()
    - supports using flask.jsonify() in resource methods
    """

    def __init__(self, name, app=None, prefix='',
                 default_mediatype='application/json', decorators=None,
                 catch_all_404s=False, serve_challenge_on_401=False,
                 url_part_order='bae', errors=None):
        """
        :param name: prefix for every endpoint name this api registers
            (endpoints become ``'<name>.<endpoint>'``)
        :param app: optional Flask app to initialize immediately

        The remaining parameters are forwarded unchanged to
        :class:`flask_restful.Api`.
        """
        # name prefix for endpoints
        self.name = name

        # registry for individual view functions
        self._got_registered_once = False
        self.deferred_functions = []

        # automatic serializer handling
        self.deferred_serializers = []
        self.serializers = {}
        self.serializers_many = {}

        # The attributes above must exist *before* the base constructor runs:
        # when ``app`` is given, flask_restful.Api.__init__ calls
        # self.init_app(app) right away, which invokes our _init_app().
        super().__init__(app, prefix=prefix,
                         default_mediatype=default_mediatype,
                         decorators=decorators,
                         catch_all_404s=catch_all_404s,
                         serve_challenge_on_401=serve_challenge_on_401,
                         url_part_order=url_part_order,
                         errors=errors)

        # configure a customized output_json function so that we can use
        # Flask's current_app.json_encoder setting. Assigned after the base
        # constructor, which installs its own default representations.
        self.representations = {
            'application/json': output_json,
        }

    def _init_app(self, app):
        """Finish registering this api with ``app``.

        Runs the deferred view-function registrations, instantiates the
        serializers found in ``app.serializers`` (plus any overrides
        registered via :meth:`serializer`), attaches them to the queued
        model resources, and installs a model-aware JSON encoder on ``app``.

        :raises KeyError: if a queued model resource has no serializer
        """
        super()._init_app(app)
        self._got_registered_once = True

        # register individual view functions with the app
        for deferred in self.deferred_functions:
            deferred(app)

        # instantiate serializers
        for serializer_class in app.serializers.values():
            model_name = serializer_class.Meta.model.__name__
            self.serializers[model_name] = serializer_class()
            self.serializers_many[model_name] = serializer_class(many=True)

        # register serializer overrides
        for model_name, serializer_class, many in self.deferred_serializers:
            if many:
                self.serializers_many[model_name] = serializer_class(many=True)
            else:
                self.serializers[model_name] = serializer_class()

        # attach serializers to Resource instances so that they can perform
        # automatic deserialization from json requests
        for resource, _, _ in self.resources:
            model = getattr(resource, 'model', None)
            if model is None:
                # plain resources (e.g. registered via @api.resource) have no
                # model and therefore need no serializer
                continue
            model_name = model.__name__
            if model_name not in self.serializers:
                raise KeyError(
                    f'Could not find a serializer for the {model_name} model!')
            resource.serializer = self.serializers[model_name]
            resource.serializer_create = self.serializers[model_name].__class__()
            resource.serializer_create.context['is_create'] = True

        self._register_json_encoder(app, self.serializers)

    def resource(self, *urls, **kwargs):
        """Decorator to wrap a :class:`~flask_restful.Resource` class, adding
        it to the api. Parameters are the same as
        :meth:`~flask_restful.Api.add_resource`.

        Example::

            app = Flask(__name__)
            api = Api('api', app)

            @api.resource('/foo')
            class FooResource(Resource):
                def get(self):
                    return 'Hello, World!'

        Overridden to customize the endpoint name. A
        :class:`~flask.Blueprint` may be passed as the first argument, in
        which case its ``url_prefix`` is prepended to every url.
        """
        if urls and isinstance(urls[0], Blueprint):
            bp = urls[0]
            # a tuple (not a generator) so the decorator can be reused
            urls = tuple(f"{bp.url_prefix or ''}{url}" for url in urls[1:])

        def decorator(cls):
            # copy so that applying the decorator does not mutate ``kwargs``
            options = dict(kwargs)
            endpoint = self._get_endpoint(cls, options.pop('endpoint', None))
            self.add_resource(cls, *urls, endpoint=endpoint, **options)
            return cls
        return decorator

    def model_resource(self, *args, **kwargs):
        """Decorator to wrap a :class:`backend.api.ModelResource` class,
        adding it to the api. There are two supported method signatures:

        `Api.model_resource(model, *urls, **kwargs)`

        and

        `Api.model_resource(blueprint, model, *urls, *kwargs)`

        Example without blueprint::

            from backend.extensions.api import api
            from models import User

            @api.model_resource(User, '/users', '/users/<int:id>')
            class UserResource(Resource):
                def get(self, user):
                    return user

                def list(self, users):
                    return users

        Example with blueprint::

            from backend.extensions.api import api
            from models import User
            from views import bp

            @api.model_resource(bp, User, '/users', '/users/<int:id>')
            class UserResource(Resource):
                def get(self, user):
                    return user

                def list(self, users):
                    return users
        """
        bp, model, urls = _get_model_resource_args(args)
        if bp:
            # a tuple (not a generator) so the decorator can be reused
            urls = tuple(f"{bp.url_prefix or ''}{url}" for url in urls)

        def decorator(cls):
            cls.model = model
            # copy so that applying the decorator does not mutate ``kwargs``
            options = dict(kwargs)
            endpoint = self._get_endpoint(cls, options.pop('endpoint', None))
            self.add_resource(cls, *urls, endpoint=endpoint, **options)
            return cls
        return decorator

    def serializer(self, *args, many=False):
        """Decorator to wrap a :class:`~backend.api.ModelSerializer` class,
        registering the wrapped serializer as the specific one to use for the
        serializer's model. For example::

            from backend.extensions.api import api
            from backend.api import ModelSerializer
            from models import Foo

            @api.serializer  # @api.serializer() works too
            class FooSerializer(ModelSerializer):
                class Meta:
                    model = Foo

            @api.serializer(many=True)
            class FooListSerializer(ModelSerializer):
                class Meta:
                    model = Foo
        """
        def decorator(serializer_class):
            model_name = serializer_class.Meta.model.__name__
            self.deferred_serializers.append((model_name, serializer_class, many))
            return serializer_class

        if was_decorated_without_parenthesis(args):
            return decorator(args[0])
        return decorator

    def route(self, *args, **kwargs):
        """Decorator for registering individual view functions.

        Usage without blueprint::

            api = Api('api', prefix='/api/v1')

            @api.route('/foo')  # resulting url: /api/v1/foo
            def get_foo():
                # do stuff

        Usage with blueprint::

            api = Api('api', prefix='/api/v1')
            team = Blueprint('team', url_prefix='/team')

            @api.route(team, '/users')  # resulting url: /api/v1/team/users
            def users():
                # do stuff
        """
        url = args[0]
        if isinstance(url, Blueprint):
            bp = url
            url = f"{bp.url_prefix or ''}{args[1]}"

        def decorator(fn):
            # copy so that applying the decorator does not mutate ``kwargs``
            options = dict(kwargs)
            endpoint = self._get_endpoint(fn, options.pop('endpoint', None))
            self.add_url_rule(url, endpoint, fn, **options)
            return fn
        return decorator

    def add_url_rule(self, rule, endpoint=None, view_func=None, **kwargs):
        """Defer ``app.add_url_rule`` for ``self.prefix + rule`` until the
        api is registered with an app.

        :raises ValueError: if ``rule`` does not start with ``/``
        """
        if not rule.startswith('/'):
            raise ValueError('URL rule must start with a forward slash (/)')

        rule = self.prefix + rule
        self.record(
            lambda _app: _app.add_url_rule(rule, endpoint, view_func, **kwargs)
        )

    def record(self, fn):
        """Queue ``fn(app)`` to be called from :meth:`_init_app`.

        Warns when the api has already been registered: the function is
        still queued, but apps that were already initialized will not see it.
        """
        if self._got_registered_once:
            from warnings import warn
            warn(Warning('The api was already registered once but is getting'
                         ' modified now. These changes will not show up.'))
        self.deferred_functions.append(fn)

    def _get_endpoint(self, view_func, endpoint=None, plural=False):
        """Return the namespaced endpoint name ``'<self.name>.<endpoint>'``.

        An explicit ``endpoint`` is used as-is (it must not contain dots).
        Otherwise class-based views use their snake_cased class name, or,
        when ``plural`` is set and the view has a ``model``,
        ``'<snake_cased model.__plural__>_resource'``; plain functions use
        their ``__name__``.
        """
        if endpoint:
            assert '.' not in endpoint, 'Api endpoints should not contain dots'
        elif isinstance(view_func, MethodViewType):
            endpoint = camel_to_snake_case(view_func.__name__)
            if hasattr(view_func, 'model') and plural:
                plural_model = camel_to_snake_case(view_func.model.__plural__)
                endpoint = f'{plural_model}_resource'
        else:
            endpoint = view_func.__name__
        return f'{self.name}.{endpoint}'

    def _register_json_encoder(self, app, serializers):
        """Install on ``app`` a JSON encoder subclassing the app's current
        encoder that serializes enums by name and models via ``serializers``
        (keyed by model class name).
        """
        BaseEncoderClass = app.json_encoder or BaseJSONEncoder

        class JSONEncoder(BaseEncoderClass):
            def default(self, o):
                if isinstance(o, enum.Enum):
                    return o.name
                if isinstance(o, Model):
                    model_name = o.__class__.__name__
                    if model_name in serializers:
                        return serializers[model_name].dump(o).data
                return super().default(o)

        app.json_encoder = JSONEncoder

    def make_response(self, data, *args, **kwargs):
        """Overridden to support returning already-formed Responses
        unmodified, as well as automatic serialization of lists of sqlalchemy
        models (serialization of individual models is handled by a custom
        JSONEncoder class configured in the self._register_json_encoder
        method)
        """
        # we've already got a response, eg, from jsonify
        if isinstance(data, Response):
            # Apply any headers passed by keyword alongside the response
            # instead of silently dropping them.
            headers = kwargs.get('headers')
            if headers:
                data.headers.extend(headers)
            return (data, *args)

        if (isinstance(data, (list, tuple)) and data
                and isinstance(data[0], Model)):
            model_name = data[0].__class__.__name__
            if model_name in self.serializers_many:
                data = self.serializers_many[model_name].dump(data).data

        # we got the result of serializer.dump(obj)
        if isinstance(data, MarshalResult):
            data = data.data

        # we got plain python data types that need to be serialized
        return super().make_response(data, *args, **kwargs)

    def _register_view(self, app, resource, *urls, **kwargs):
        """Overridden to handle custom method names on ModelResources

        Non-ModelResources, or calls with explicit ``methods``, go through
        the base implementation. Otherwise each url is registered on its
        own: when ``get_last_param_name(url)`` is truthy (presumably the url
        ends in a ``<param>`` segment) it uses the singular endpoint with
        GET/HEAD, DELETE, PATCH and PUT as implemented; otherwise it uses the
        plural endpoint with GET/HEAD (list) and POST (create).
        """
        if not issubclass(resource, ModelResource) or 'methods' in kwargs:
            return super()._register_view(app, resource, *urls, **kwargs)

        for url in urls:
            endpoint = self._get_endpoint(resource)
            http_methods = []
            has_last_param = get_last_param_name(url)
            if has_last_param:
                if ModelResource.has_method(resource, GET):
                    http_methods += ['GET', 'HEAD']
                if ModelResource.has_method(resource, DELETE):
                    http_methods += ['DELETE']
                if ModelResource.has_method(resource, PATCH):
                    http_methods += ['PATCH']
                if ModelResource.has_method(resource, PUT):
                    http_methods += ['PUT']
            else:
                endpoint = self._get_endpoint(resource, plural=True)
                if ModelResource.has_method(resource, LIST):
                    http_methods += ['GET', 'HEAD']
                if ModelResource.has_method(resource, CREATE):
                    http_methods += ['POST']
            kwargs['endpoint'] = endpoint
            super()._register_view(app, resource, url, **kwargs,
                                   methods=http_methods)
def output_json(data, code, headers=None):
    """Replaces Flask-RESTful's default output_json function, using
    Flask.json's dumps method instead of the stock Python json.dumps.

    Mainly this means we end up using the current app's configured
    json_encoder class.

    :param data: the python data to serialize as JSON
    :param code: the HTTP status code for the response
    :param headers: optional extra headers to add to the response
    :return: the response object built by ``make_response``
    """
    # Copy the configured settings so the debug-mode default below is not
    # written back into ``current_app.config['RESTFUL_JSON']``.
    settings = dict(current_app.config.get('RESTFUL_JSON') or {})

    # If we're in debug mode, and the indent is not set, we set it to a
    # reasonable value here.
    if current_app.debug:
        settings.setdefault('indent', 4)

    # always end the json dumps with a new line
    # see https://github.com/mitsuhiko/flask/pull/1262
    dumped = dumps(data, **settings) + '\n'

    response = make_response(dumped, code)
    response.headers.extend(headers or {})
    return response