python Django 之 DRF(五)GenericAPIView类、子类的源码分析


前言

随着数据的增多,后端api接口需要大量的定义,普通的情况通过restframework框架中的APIVIew可以解决很多问题,但是当接口需求大的时候,为了方便DRF框架提供了许多的内置函数,在我们的上篇文章中就有提及到的各种框架比如:

  1. 继承于APIView方法的GenericAPIView类
  2. 继承于ViewSetMixin类、GenericAPIView类的GenericViewSet方法
  3. 继承了多个增删改查类、GenericViewSet类,并重写了as_view方法ModelViewSet方法

等等许多的构造类方法提供给了我们方便使用,但是在这个基础之上,还有一个更为方便的类提供我们使用,并且该类继承了GenericAPIView类以及其增删改查类,该类型的方法无需重写as_view,也有自己的一套写法,方便简洁,并且每个子类函数都非常的类似,下面我们就通过阅读部分源码的方式来实现该类的用法吧!。




一、ListAPIView类源码分析

ListAPIView类继承于ListModelMixin类和GenericAPIView类,该方法无需重写as_view()方法,且功能是获取数据、转义数据、以及数据校验、返回数据、解析数据等一系列父类拥有的功能都有,可供我们在展示数据库数据据时非常方便,那么我们就通过查看源码来分析一下使用规则吧。


ListAPIView类源码如下:
class ListAPIView(mixins.ListModelMixin,
                  GenericAPIView)

    def get(self, request, *args, **kwargs):
        return self.list(request, *args, **kwargs)

从中我们可以看出ListAPIView重写了get方法,并且在get方法中调用了list函数并将数据放入,那么我们就从list函数查看( 子类没有就从父类找,从左往右的顺序依次查找),从这里我们可以发现如果当我们当前视图类自定义了list函数时,我们可以通过 super().list(request, *args, **kwargs)获取到父类中的list函数,将返回的值并进行自定义校验、赋值的操作,因为这里我们并没有重写,所以我们点进该父类的ListModelMixin方法中。

ListModelMixin类源码如下:

class ListModelMixin:
    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)

可以发现ListModelMixin类中确实有list函数,并且从中我们看到了该函数调用了 self.get_queryset()、self.filter_queryset()的函数并将返回结果赋值给了queryset,可是在ListModelMixin类中并没有这两个类,所以我们回到原先查找,这里可以发现如果我们在自定义的视图类中写入了get_queryset()函数方法,那么就可以进行重写操作,那么我们来到了父类GenericAPIView类中。

在上篇文章中我只讲述了如何去使用GenericAPIView类,但是为了让自己能更加灵活的运用,且配合今天要讲的继承类来分析GenericAPIView类的源码,或许能让自己在定义接口时有其他的思路和更方便的方法。
GenericAPIView类如下:
class GenericAPIView(views.APIView):

    queryset = None
    serializer_class = None
    lookup_field = 'pk'
    lookup_url_kwarg = None
    filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
    pagination_class = api_settings.DEFAULT_PAGINATION_CLASS

    def get_queryset(self):
   
        assert self.queryset is not None, (
                "'%s' should either include a `queryset` attribute, "
                "or override the `get_queryset()` method."
                % self.__class__.__name__
        )

        queryset = self.queryset
        if isinstance(queryset, QuerySet):
            # Ensure queryset is re-evaluated on each request.
            queryset = queryset.all()
        return queryset

    def get_object(self):

        queryset = self.filter_queryset(self.get_queryset())

        # Perform the lookup filtering.
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

        assert lookup_url_kwarg in self.kwargs, (
                'Expected view %s to be called with a URL keyword argument '
                'named "%s". Fix your URL conf, or set the `.lookup_field` '
                'attribute on the view correctly.' %
                (self.__class__.__name__, lookup_url_kwarg)
        )

        filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
        obj = get_object_or_404(queryset, **filter_kwargs)

        # May raise a permission denied
        self.check_object_permissions(self.request, obj)

        return obj

    def get_serializer(self, *args, **kwargs):
        serializer_class = self.get_serializer_class()
        kwargs.setdefault('context', self.get_serializer_context())
        return serializer_class(*args, **kwargs)

    def get_serializer_class(self):
        assert self.serializer_class is not None, (
                "'%s' should either include a `serializer_class` attribute, "
                "or override the `get_serializer_class()` method."
                % self.__class__.__name__
        )

        return self.serializer_class

    def get_serializer_context(self):
        return {
            'request': self.request,
            'format': self.format_kwarg,
            'view': self
        }

    def filter_queryset(self, queryset):
        for backend in list(self.filter_backends):
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    @property
    def paginator(self):
 
        if not hasattr(self, '_paginator'):
            if self.pagination_class is None:
                self._paginator = None
            else:
                self._paginator = self.pagination_class()
        return self._paginator

    def paginate_queryset(self, queryset):
   
        if self.paginator is None:
            return None
        return self.paginator.paginate_queryset(queryset, self.request, view=self)

    def get_paginated_response(self, data):
    
        assert self.paginator is not None
        return self.paginator.get_paginated_response(data)

可以看到GenericAPIView方法还是有许多的函数的,此时我们先通过查看get_queryset函数开始吧。

get_queryset函数如下:

 def get_queryset(self):
        assert self.queryset is not None, (
                "'%s' should either include a `queryset` attribute, "
                "or override the `get_queryset()` method."
                % self.__class__.__name__
        )

        queryset = self.queryset
        if isinstance(queryset, QuerySet):
            # Ensure queryset is re-evaluated on each request.
            queryset = queryset.all()
        return queryset

可以发现该方法获取了我们的self.queryset参数,所以在定义ListAPIView的时候需要传入一个数据库查询对象,并且判断了 类型是否为QuerySet类型,如果是的则添加.all()函数,否指直接返回queryset,这表明了我们在 自定义queryset查询时,可以不用写入后面的获取所有条件,看到这里我们就能明白get_queryset函数其实就是帮我们从数据库获取我们想要的参数,那么我们如果想重写get_queryset函数的话也需按照该规则实现即可,且可以无需在自定义视图类中增加queryset对象。

1.BaseFilterBackend类的使用

filter_queryset函数如下:

 def filter_queryset(self, queryset):
      
        for backend in list(self.filter_backends):
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

从中看出该函数变量了self.filter_backends对象,且该对象为配置文件中的筛选器(可以在settings.py中配置上筛选器,也可以使用DRF内置继承类BaseFilterBackend方法来搭配使用),然后通过遍历的方法,调用子类的filter_queryset函数(这是我们自定义筛选函数中的filter_backends函数而当前类中传入了self.request, queryset, self这些参数到子类中),作用是将在我们自定义queryset对象是使用filter来筛选数据库对象,从而获取到筛选后的对象并赋值给了queryset并返回

BaseFilterBackend类源码如下:

class BaseFilterBackend:

    def filter_queryset(self, request, queryset, view):
     
        raise NotImplementedError(".filter_queryset() must be overridden.")

    def get_schema_fields(self, view):
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return []

    def get_schema_operation_parameters(self, view):
        return []

对于BaseFilterBackend类中还有继承于该类的子类(SearchFilter,OrderingFilter类。这里我们就不探讨了,有兴趣的朋友可以自行查看),我们从中可以看出BaseFilterBackend在使用时有个filter_queryset函数,该函数就是我们在自定义筛选类中重写用的,参数即为父类传入的参数。


那么此时我们就可以知道了list函数中的queryset即为数据库查询、筛选完展示出的对象值,此时返回list函数继续往下走。

ListModelMixin类源码如下:

class ListModelMixin:
    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)

可以发现此时将在数据库筛选完的对象给了分页paginate_queryset函数,并且将该返回的参数给了page(即为分页完成后展示的数据),并判断了page是否存在,如果存在那么则调用了get_serializer函数将page带上,many=True(表示传递的数据有多个),在通过get_paginated_response函数将分页完成的数据带上响应值、序列化的数据内部函数通过Response函数解析返回,而如果没有分页则调用get_serializer函数并直接传入queryset对象,在通过Response函数返回,因为之前也讲过了分页操作,所以这里不详细讲,所以我们只查看
get_serializer相关的函数的源码。


get_serializer相关的函数如下:

def get_serializer(self, *args, **kwargs):
        serializer_class = self.get_serializer_class()
        kwargs.setdefault('context', self.get_serializer_context())
        return serializer_class(*args, **kwargs)

    def get_serializer_class(self):
        assert self.serializer_class is not None, (
                "'%s' should either include a `serializer_class` attribute, "
                "or override the `get_serializer_class()` method."
                % self.__class__.__name__
        )

        return self.serializer_class

    def get_serializer_context(self):
        return {
            'request': self.request,
            'format': self.format_kwarg,
            'view': self
        }

从中我们可以看出,该函数调用了get_serializer_class来判断是否在 self.serializer_class序列化类中,存在就返回,然后调用了get_serializer_context函数,该函数将其封装,其中有request,view,format参数(这里比较关键,表明如果当我们的视图中有继承GenericAPIView类时,那么我们的序列化serializer_class类中即拥有了我们父类的所有参数,且是通过context关键字获取。),最后通过kwargs存入,之后返回serializer_class(*args, **kwargs)函数。


那么这就是ListAPIView方法的流程了,那么我们做一下总结:

  • 1.子类可以通过get_queryset函数重写queryset中的查询对象
  • 2.子类通过重写list函数的基础上给父类返回的参数在进行验证、赋值
  • 3.通过继承BaseFilterBackend类来重写filter_queryset函数实现自定义筛选方法



二、RetrieveAPIView类源码分析

RetrieveAPIView类和ListAPIView类似都继承了一个GenericAPIView父类,不过RetrieveAPIView类还继承了一个RetrieveModelMixin类和ListModelMixin类不同它可以获取请求url路径上的参数来通过数据库进行筛选取值,不过他们两者都是通过重写get方法获取,所以在自定义接口类中是不能共同使用的。

RetrieveAPIView类如下:

class RetrieveAPIView(mixins.RetrieveModelMixin,
                      GenericAPIView):

    def get(self, request, *args, **kwargs):
        return self.retrieve(request, *args, **kwargs)

同样的重写于get请求,都是返回的是retrieve函数,是父类RetrieveModelMixin类的函数,此时我们点进去查看。


RetrieveModelMixin函数如下:

class RetrieveModelMixin:

    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

可以发现retrieve函数方法和list方法特别类似,不过在序列化参数时,该函数调用的是get_object()函数,此时我们发现该函数也是GenericAPIView函数的方法,那么我们点进去查看一下。

GenericAPIView类/get_object函数如下:

class GenericAPIView(views.APIView):

    queryset = None
    serializer_class = None
    lookup_field = 'pk'
    lookup_url_kwarg = None
    filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
    pagination_class = api_settings.DEFAULT_PAGINATION_CLASS


    def get_object(self):

        queryset = self.filter_queryset(self.get_queryset())

        # Perform the lookup filtering.
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field

        assert lookup_url_kwarg in self.kwargs, (
                'Expected view %s to be called with a URL keyword argument '
                'named "%s". Fix your URL conf, or set the `.lookup_field` '
                'attribute on the view correctly.' %
                (self.__class__.__name__, lookup_url_kwarg)
        )

        filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
        obj = get_object_or_404(queryset, **filter_kwargs)

        # May raise a permission denied	
        self.check_object_permissions(self.request, obj)

        return obj

从第一步可以看出和我们的ListAPIView类相同,先获取筛选后的参数然后赋值给了queryset,不过在后面get_object函数调用了(self.lookup_url_kwarg or self.lookup_field一个键一个值,键值默认"pk"),然后进行了判断是否在self.kwargs中(即在定义url上面的参数),然后将获取到的参数放入到filter_kwargs中({pk:“id”}),然后调用了get_object_or_404来判断前端传来的url同上的数据是否有误,之后又进行了一次筛选,将对象返回给了obj,之后就判断了一下是否有自定义的权限函数,如果有就进行权限判断,之后在返回


那么这就是RetrieveAPIView源码方法的流程了,和ListAPIView类似,那么我们做一下总结:

  • 1.在通过函数get_object中获取queryset的基础上增加了一个路由id的筛选
  • 2.增加了一个权限认证的函数校验



三、CreateAPIView类源码分析

CreateAPIView类和上面的获取值的操作就稍有不同了,不过他们都有共同的GenericAPIView父类,并且继承了CreateModelMixin类,用于保存数据的操作,发送请求的函数是POST请求。

CreateAPIView类如下:

class CreateAPIView(mixins.CreateModelMixin,
                    GenericAPIView):
    def post(self, request, *args, **kwargs):
        return self.create(request, *args, **kwargs)

此时我们可以发现CreateAPIView类是重写了post函数的,并调用了self.create函数,在CreateModelMixin类中,所以我们点进去查看一下。

CreateModelMixin类如下:

class CreateModelMixin:

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def perform_create(self, serializer):
        serializer.save()

    def get_success_headers(self, data):
        try:
            return {'Location': str(data[api_settings.URL_FIELD_NAME])}
        except (TypeError, KeyError):
            return {}

可以看到方法从create函数走过,第一个调用了get_serializer,和前面的不同,这次的调用方法是赋值给了data,而request.data就是我们通过POST请求发来的数据并解析给我们的数据,因为我们读过get_serializer类的源码,知道其实该方法就是通过获取我们的serializer_class的自定义序列化类,并通过data的形式传入,进行表单验证,将校验完成的对象返回给serializer ,之后在调用is_valid来判断是否验证通过,通过之后调用了perform_create函数进行保存( 这里我们可以自定义子类中重写来实现保存前添加点其他参数在保存,或在保存前判断,保存后面数据的方法为serializer.save(数据库名="参数")),最后在通过调用get_success_headers获取成功后的请求头返回给headers,最后通过Response方法返回。

对于ModelSerializer类来说,该类拥有一个create函数,并且传入的参数为self, validated_data,表示该类有父类的方法,而validated_data则是ModelSerializer通过序列化,进行表单验证成功后的结果,往往我们可以通过在这个验证成功的结果上添加一些其他的验证规则,以及添加其他的数据,此时我们来查看一下create函数的源码:


2.ModelSerializer类中的扩展

ModelSerializer类/create函数如下:

    def create(self, validated_data):
        raise_errors_on_nested_writes('create', self, validated_data)

        ModelClass = self.Meta.model

        info = model_meta.get_field_info(ModelClass)
        many_to_many = {}
        for field_name, relation_info in info.relations.items():
            if relation_info.to_many and (field_name in validated_data):
                many_to_many[field_name] = validated_data.pop(field_name)

        try:
            instance = ModelClass._default_manager.create(**validated_data)
        except TypeError:
            tb = traceback.format_exc()
            msg = (
                'Got a `TypeError` when calling `%s.%s.create()`. '
                'This may be because you have a writable field on the '
                'serializer class that is not a valid argument to '
                '`%s.%s.create()`. You may need to make the field '
                'read-only, or override the %s.create() method to handle '
                'this correctly.\nOriginal exception was:\n %s' %
                (
                    ModelClass.__name__,
                    ModelClass._default_manager.name,
                    ModelClass.__name__,
                    ModelClass._default_manager.name,
                    self.__class__.__name__,
                    tb
                )
            )
            raise TypeError(msg)

        if many_to_many:
            for field_name, value in many_to_many.items():
                field = getattr(instance, field_name)
                field.set(value)

        return instance

这里的大概意思应该是通过get_field_info获取字段的信息,然后定义一个many_to_many 的字典来存放我们字段的键和值,然后通过字段的create函数来创建值,并返回给instance,此时instance获取到了刚创建好的对象,最后判断many_to_many是否存在,存在则获取字段的方法,然后在通过set函数存入,最后将序列化完毕的instance对象返回。所以我们只需要在重写create函数的时候返回通过创建得到的对象即可。


假设前端返回给我们的数据如下:

{
        "user": {
            "nickname": "sehun粉",
            "phone":'1234567890',
        },
        "cover": "https://wx-1304867879.cos.ap-guangzhou.myqcloud.com/aedbGbQfa8fZerPNSdHAbAGbx4cNncacwx.jpg",
        "content": "哔哔哔",
        "address": "北京四合院",
        "favor_count": 0,
        "viewer_count": 0,
        "comment_count": 0,
        "create_date": "2021-07-11T15:42:22.154529Z"
}

此时我们分析一下需求:

  1. 将这user的数据和其他的数据分别保存在不同的表中
  2. user对象需要token才能添加
  3. Address表结构还需要用户的对象
  4. Address对象最后需要让viewer_count+1并保存

serializer.py如下:

from rest_framework import serializers
from api import models
import uuid
class UserInfo(serializers.Serializer):
    nickname= serializers.CharField()
    phone= serializers.CharField()


class Address(serializers.ModelSerializer):
    user= UserInfo(many=True)  # 传入一个字典userinfo,包含了nickname和phone

    class Meta:
        model = models.Address
        fields = "__all__"
 	def create(self, validated_data):
       user_list = validated_data.pop('user')
       # 获取用户信息
       user_list['token'] = str(uuid.uuid4())
       # 保存到UserInfo表
       user_object = models.UserInfo.objects.create(**user_list )
       validated_data['user'] = user_object
       # 保存到Address表中
       address_object = models.Address.objects.create(**validated_data)

       return address_object 

此时继承了CreateAPIView类的自定义视图类AddressView.py如下:

from .serializer.py import Address
from django.db.models import F
class AddressView(CreateAPIView):
    serializer_class = Address
    # 通过CreateAPIView先执行perform_create调用sava方法通过create保存
    def perform_create(self, serializer):
        return serializer.save(viewer_count=F('viewer_count') + 1)

那么我们来进行一下CreateAPIView类的总结:

  1. 可以在自定义类的基础上重写perform_create,在此进行一些判断条件或增加参数
  2. 可以配合着ModelSerializer类中create方法结合使用



四、UpdateAPIView类源码分析

UpdateAPIView类和CreateAPIView比较类似,也是一样和上面几个函数同样的继承了GenericAPIView,不过不同的是第二个父类继承了UpdateModelMixin类,此时我们先来看一下UpdateAPIView的源码吧。

class UpdateAPIView(mixins.UpdateModelMixin,
                    GenericAPIView):

    def put(self, request, *args, **kwargs):
        return self.update(request, *args, **kwargs)

    def patch(self, request, *args, **kwargs):
        return self.partial_update(request, *args, **kwargs)

这次我们可以发现,他和其他原先的不太一样,有两个重写的函数put、patch,不过这两个函数调用的都是UpdateModelMixin中的函数update、partial_update。它们的区别在于当前端发送了patch请求表示着我创来的数据只修改一部分,而发送put请求就表示修改全部。


此时我们点开UpdateModelMixin类源码如下:

class UpdateModelMixin:
    def update(self, request, *args, **kwargs):
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            instance._prefetched_objects_cache = {}
        return Response(serializer.data)

    def perform_update(self, serializer):
        serializer.save()

    def partial_update(self, request, *args, **kwargs):
        kwargs['partial'] = True
        return self.update(request, *args, **kwargs)

从中我们可以发现partial_update(即为发送patch请求的时候),它实际上也是调用了update函数,不过是将kwargs[‘partial’] = True,即部分修改。所以我们把重心放到update函数中,可以发现一开始该函数就通过kwargs.pop(‘partial’, False)来取值,如果没有获取到则设置为False并赋值给了partial,然后通过get_object函数获取原先数据,之后再通过get_serializer函数传入原先的数据和要修改的数据,和partial传入,然后进行表单验证,通过之后保存修改过的数据,所以我们可以通过在准备修改的基础上重写perform_update,在此进行一些判断条件或增加参数


那么我们来进行一下UpdateAPIView类的总结:

  1. 通过调用get_object函数获取原先数据(有需要也可以重写get_queryset函数)
  2. 通过请求函数不同改变partial的值(布尔类型用于判断是否全部修改)
  3. 可以在自定义类的基础上重写perform_update,在此进行一些判断条件或增加参数



restframework非常的强大,有许多的内置框架,并且可以搭配使用,对于我们这篇所讲的也只是冰山一角,还有许多的方法等大家挖掘,这里其实还有一个DestroyAPIView方法我们没讲,是一个调用delete函数的方法,大家可以自行通过源码阅读,和我们上面的方法基本类似,大家也能发现,当我们明白其中一种内置类方法后,其余类似的内置类其实也就没什么不同的了。

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值