From 8b48ff9c056423da32227630c0e137de5c390339 Mon Sep 17 00:00:00 2001 From: Tanner Collin Date: Wed, 22 Jan 2020 22:32:58 +0000 Subject: [PATCH] Move utils and permissions to own file, clean up code --- apiserver/apiserver/api/fields.py | 14 ++ apiserver/apiserver/api/permissions.py | 46 ++++++ apiserver/apiserver/api/serializers.py | 204 +++++++++---------------- apiserver/apiserver/api/utils.py | 118 ++++++++++++++ apiserver/apiserver/api/views.py | 76 +++------ 5 files changed, 272 insertions(+), 186 deletions(-) create mode 100644 apiserver/apiserver/api/fields.py create mode 100644 apiserver/apiserver/api/permissions.py diff --git a/apiserver/apiserver/api/fields.py b/apiserver/apiserver/api/fields.py new file mode 100644 index 0000000..d0de9da --- /dev/null +++ b/apiserver/apiserver/api/fields.py @@ -0,0 +1,14 @@ +from rest_framework import serializers +from . import utils + +class UserEmailField(serializers.ModelField): + def to_representation(self, obj): + return getattr(obj.user, 'email', obj.old_email) + def to_internal_value(self, data): + return serializers.EmailField().run_validation(data) + +class HTMLField(serializers.CharField): + def to_internal_value(self, data): + data = utils.clean(data) + return super().to_internal_value(data) + diff --git a/apiserver/apiserver/api/permissions.py b/apiserver/apiserver/api/permissions.py new file mode 100644 index 0000000..46e999f --- /dev/null +++ b/apiserver/apiserver/api/permissions.py @@ -0,0 +1,46 @@ +from rest_framework.permissions import BasePermission, IsAuthenticated, SAFE_METHODS + +class AllowMetadata(BasePermission): + def has_permission(self, request, view): + return request.method in ['OPTIONS', 'HEAD'] + +def is_admin_director(user): + return bool(user.is_staff or user.member.is_director or user.member.is_staff) + +class IsObjOwnerOrAdmin(BasePermission): + def has_object_permission(self, request, view, obj): + return bool(request.user + and (obj.user == request.user + or is_admin_director(request.user) + ) + ) + +class IsSessionInstructorOrAdmin(BasePermission): + def has_object_permission(self, request, view, obj): + return bool(request.user + and (obj.session.instructor == request.user + or is_admin_director(request.user) + ) + ) + +class ReadOnly(BasePermission): + def has_permission(self, request, view): + return bool(request.method in SAFE_METHODS) + def has_object_permission(self, request, view, obj): + return bool(request.method in SAFE_METHODS) + +class IsAdminOrReadOnly(BasePermission): + def has_permission(self, request, view): + return bool( + request.method in SAFE_METHODS + or request.user + and is_admin_director(request.user) + ) + +class IsInstructorOrReadOnly(BasePermission): + def has_permission(self, request, view): + return bool( + request.method in SAFE_METHODS + or request.user + and request.user.member.is_instructor + ) diff --git a/apiserver/apiserver/api/serializers.py b/apiserver/apiserver/api/serializers.py index 6bfccf9..f1777e5 100644 --- a/apiserver/apiserver/api/serializers.py +++ b/apiserver/apiserver/api/serializers.py @@ -5,87 +5,43 @@ from rest_framework.exceptions import ValidationError from rest_framework.validators import UniqueValidator from rest_auth.registration.serializers import RegisterSerializer from rest_auth.serializers import UserDetailsSerializer -from uuid import uuid4 -from PIL import Image -from bleach.sanitizer import Cleaner -from . import models, old_models +from . import models, old_models, fields, utils from .. import settings -#custom_error = lambda x: ValidationError(dict(non_field_errors=x)) - -STATIC_FOLDER = 'data/static/' -LARGE_SIZE = 1080 -MEDIUM_SIZE = 220 -SMALL_SIZE = 110 - -def process_image(upload): - try: - pic = Image.open(upload) - except OSError: - raise serializers.ValidationError('Invalid image file.') - - if pic.format == 'PNG': - ext = '.png' - elif pic.format == 'JPEG': - ext = '.jpg' - else: - raise serializers.ValidationError('Image must be a jpg or png.') - - large = str(uuid4()) + ext - pic.thumbnail([LARGE_SIZE, LARGE_SIZE], Image.ANTIALIAS) - pic.save(STATIC_FOLDER + large) - - medium = str(uuid4()) + ext - pic.thumbnail([MEDIUM_SIZE, MEDIUM_SIZE], Image.ANTIALIAS) - pic.save(STATIC_FOLDER + medium) - - small = str(uuid4()) + ext - pic.thumbnail([SMALL_SIZE, SMALL_SIZE], Image.ANTIALIAS) - pic.save(STATIC_FOLDER + small) - - return small, medium, large - - -ALLOWED_TAGS = [ - 'h3', - 'p', - 'br', - 'strong', - 'em', - 'u', - 'code', - 'ol', - 'li', - 'ul', - 'a', - ] - -clean = Cleaner(tags=ALLOWED_TAGS).clean - - - - - -class UserEmailField(serializers.ModelField): - def to_representation(self, obj): - return getattr(obj.user, 'email', obj.old_email) - def to_internal_value(self, data): - return serializers.EmailField().run_validation(data) - -class HTMLField(serializers.CharField): - def to_internal_value(self, data): - data = clean(data) - return super().to_internal_value(data) - - class TransactionSerializer(serializers.ModelSerializer): - account_type = serializers.ChoiceField(['Interac', 'TD Chequing', 'Paypal', 'Dream Pmt', 'PayPal', 'Square Pmt', 'Member', 'Clearing', 'Cash']) - info_source = serializers.ChoiceField(['Web', 'DB Edit', 'System', 'Receipt or Stmt', 'Quicken Import', 'Paypal IPN', 'PayPal IPN', 'Auto', 'Nexus DB Bulk', 'IPN Trigger', 'Intranet Receipt', 'Automatic', 'Manual']) + # fields directly from old portal. replace with slugs we want + account_type = serializers.ChoiceField([ + 'Interac', + 'TD Chequing', + 'Paypal', + 'Dream Pmt', + 'PayPal', + 'Square Pmt', + 'Member', + 'Clearing', + 'Cash' + ]) + info_source = serializers.ChoiceField([ + 'Web', + 'DB Edit', + 'System', + 'Receipt or Stmt', + 'Quicken Import', + 'Paypal IPN', + 'PayPal IPN', + 'Auto', + 'Nexus DB Bulk', + 'IPN Trigger', + 'Intranet Receipt', + 'Automatic', + 'Manual' + ]) member_id = serializers.IntegerField() member_name = serializers.SerializerMethodField() date = serializers.DateField() + class Meta: model = models.Transaction fields = '__all__' @@ -110,13 +66,21 @@ class TransactionSerializer(serializers.ModelSerializer): return member.preferred_name + ' ' + member.last_name - # member viewing other members class OtherMemberSerializer(serializers.ModelSerializer): status = serializers.SerializerMethodField() + class Meta: model = models.Member - fields = ['id', 'preferred_name', 'last_name', 'status', 'current_start_date', 'photo_small', 'photo_large'] + fields = [ + 'id', + 'preferred_name', + 'last_name', + 'status', + 'current_start_date', + 'photo_small', + 'photo_large' + ] def get_status(self, obj): return 'Former Member' if obj.paused_date else obj.status @@ -125,7 +89,7 @@ class OtherMemberSerializer(serializers.ModelSerializer): class MemberSerializer(serializers.ModelSerializer): status = serializers.SerializerMethodField() photo = serializers.ImageField(write_only=True, required=False) - email = UserEmailField(serializers.EmailField) + email = fields.UserEmailField(serializers.EmailField) phone = serializers.CharField() street_address = serializers.CharField() city = serializers.CharField() @@ -165,7 +129,7 @@ class MemberSerializer(serializers.ModelSerializer): photo = validated_data.get('photo', None) if photo: - small, medium, large = process_image(photo) + small, medium, large = utils.process_image_upload(photo) instance.photo_small = small instance.photo_medium = medium instance.photo_large = large @@ -190,7 +154,6 @@ class AdminMemberSerializer(MemberSerializer): ] - # member viewing member list or search result class SearchSerializer(serializers.Serializer): q = serializers.CharField(write_only=True, max_length=64) @@ -232,14 +195,19 @@ class AdminSearchSerializer(serializers.Serializer): return serializer.data - class CardSerializer(serializers.ModelSerializer): card_number = serializers.CharField(validators=[UniqueValidator( queryset=models.Card.objects.all(), message='Card number already exists.' )]) member_id = serializers.IntegerField() - active_status = serializers.ChoiceField(['card_blocked', 'card_inactive', 'card_member_blocked', 'card_active']) + active_status = serializers.ChoiceField([ + 'card_blocked', + 'card_inactive', + 'card_member_blocked', + 'card_active' + ]) + class Meta: model = models.Card fields = '__all__' @@ -256,15 +224,23 @@ class CardSerializer(serializers.ModelSerializer): return super().create(validated_data) - class TrainingSerializer(serializers.ModelSerializer): - attendance_status = serializers.ChoiceField(['waiting for payment', 'withdrawn', 'rescheduled', 'no-show', 'attended', 'confirmed']) + attendance_status = serializers.ChoiceField([ + 'waiting for payment', + 'withdrawn', + 'rescheduled', + 'no-show', + 'attended', + 'confirmed' + ]) session = serializers.PrimaryKeyRelatedField(queryset=models.Session.objects.all()) student_name = serializers.SerializerMethodField() + class Meta: model = models.Training fields = '__all__' read_only_fields = ['user', 'sign_up_date', 'paid_date', 'member_id'] + def get_student_name(self, obj): if obj.user: member = obj.user.member @@ -277,7 +253,6 @@ class StudentTrainingSerializer(TrainingSerializer): attendance_status = serializers.ChoiceField(['waiting for payment', 'withdrawn']) - class SessionSerializer(serializers.ModelSerializer): student_count = serializers.SerializerMethodField() course_name = serializers.SerializerMethodField() @@ -285,14 +260,18 @@ class SessionSerializer(serializers.ModelSerializer): datetime = serializers.DateTimeField() course = serializers.PrimaryKeyRelatedField(queryset=models.Course.objects.all()) students = TrainingSerializer(many=True, read_only=True) + class Meta: model = models.Session fields = '__all__' read_only_fields = ['old_instructor', 'instructor'] + def get_student_count(self, obj): return len([x for x in obj.students.all() if x.attendance_status != 'withdrawn']) + def get_course_name(self, obj): return obj.course.name + def get_instructor_name(self, obj): if obj.instructor and hasattr(obj.instructor, 'member'): name = '{} {}'.format(obj.instructor.member.preferred_name, obj.instructor.member.last_name[0]) @@ -304,7 +283,6 @@ class SessionListSerializer(SessionSerializer): students = None - class CourseSerializer(serializers.ModelSerializer): class Meta: model = models.Course @@ -313,13 +291,12 @@ class CourseSerializer(serializers.ModelSerializer): class CourseDetailSerializer(serializers.ModelSerializer): sessions = SessionListSerializer(many=True, read_only=True) name = serializers.CharField(max_length=100) - description = HTMLField(max_length=6000) + description = fields.HTMLField(max_length=6000) class Meta: model = models.Course fields = '__all__' - class UserTrainingSerializer(serializers.ModelSerializer): session = SessionListSerializer() class Meta: @@ -334,7 +311,15 @@ class UserSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ['id', 'username', 'member', 'transactions', 'cards', 'training', 'is_staff'] + fields = [ + 'id', + 'username', + 'member', + 'transactions', + 'cards', + 'training', + 'is_staff' + ] depth = 1 def get_transactions(self, obj): @@ -346,15 +331,6 @@ class UserSerializer(serializers.ModelSerializer): return serializer.data -def request_from_protospace(request): - whitelist = ['24.66.110.96', '205.233.15.76', '205.233.15.69'] - - # set (not appended) directly by nginx so we can trust it - real_ip = request.META.get('HTTP_X_REAL_IP', False) - - return real_ip in whitelist - - class RegistrationSerializer(RegisterSerializer): first_name = serializers.CharField(max_length=32) last_name = serializers.CharField(max_length=32) @@ -365,44 +341,12 @@ class RegistrationSerializer(RegisterSerializer): is_test_signup = bool(settings.DEBUG and data['last_name'] == 'tester') - if not request_from_protospace(request) and not is_test_signup: + if not utils.is_request_from_protospace(request) and not is_test_signup: user.delete() raise ValidationError(dict(non_field_errors='Can only register from Protospace.')) if data['existing_member'] == 'true': - old_members = old_models.Members.objects.using('old_portal') - try: - old_member = old_members.get(email=data['email']) - except old_models.Members.DoesNotExist: - user.delete() - raise ValidationError(dict(email='Unable to find email in old database.')) - - member = models.Member.objects.get(id=old_member.id) - - if member.user: - raise ValidationError(dict(email='Old member already claimed.')) - - member.user = user - member.first_name = data['first_name'] - member.last_name = data['last_name'] - member.preferred_name = data['first_name'] - member.save() - - transactions = models.Transaction.objects.filter(member_id=member.id) - for t in transactions: - t.user = user - t.save() - - cards = models.Card.objects.filter(member_id=member.id) - for c in cards: - c.user = user - c.save() - - training = models.Training.objects.filter(member_id=member.id) - for t in training: - t.user = user - t.save() - + utils.link_old_member(data, user) else: models.Member.objects.create( user=user, diff --git a/apiserver/apiserver/api/utils.py b/apiserver/apiserver/api/utils.py index 850e098..e4c8a26 100644 --- a/apiserver/apiserver/api/utils.py +++ b/apiserver/apiserver/api/utils.py @@ -1,5 +1,8 @@ import datetime from dateutil import relativedelta +from uuid import uuid4 +from PIL import Image +from bleach.sanitizer import Cleaner from django.db.models import Sum @@ -107,3 +110,118 @@ def tally_membership_months(member, fake_date=None): member.save() return True + + +search_strings = {} +def gen_search_strings(): + ''' + Generate a cache dict of names to member ids for rapid string matching + ''' + for m in models.Member.objects.all(): + string = '{} {}'.format( + m.preferred_name, + m.last_name, + ).lower() + search_strings[string] = m.id + + +STATIC_FOLDER = 'data/static/' +LARGE_SIZE = 1080 +MEDIUM_SIZE = 220 +SMALL_SIZE = 110 + +def process_image_upload(upload): + ''' + Save an image upload in small, medium, large sizes and return filenames + ''' + try: + pic = Image.open(upload) + except OSError: + raise serializers.ValidationError('Invalid image file.') + + if pic.format == 'PNG': + ext = '.png' + elif pic.format == 'JPEG': + ext = '.jpg' + else: + raise serializers.ValidationError('Image must be a jpg or png.') + + large = str(uuid4()) + ext + pic.thumbnail([LARGE_SIZE, LARGE_SIZE], Image.ANTIALIAS) + pic.save(STATIC_FOLDER + large) + + medium = str(uuid4()) + ext + pic.thumbnail([MEDIUM_SIZE, MEDIUM_SIZE], Image.ANTIALIAS) + pic.save(STATIC_FOLDER + medium) + + small = str(uuid4()) + ext + pic.thumbnail([SMALL_SIZE, SMALL_SIZE], Image.ANTIALIAS) + pic.save(STATIC_FOLDER + small) + + return small, medium, large + + +ALLOWED_TAGS = [ + 'h3', + 'p', + 'br', + 'strong', + 'em', + 'u', + 'code', + 'ol', + 'li', + 'ul', + 'a', +] + +clean = Cleaner(tags=ALLOWED_TAGS).clean + + +def is_request_from_protospace(request): + whitelist = ['24.66.110.96', '205.233.15.76', '205.233.15.69'] + + # set (not appended) directly by nginx so we can trust it + real_ip = request.META.get('HTTP_X_REAL_IP', False) + + return real_ip in whitelist + +def link_old_member(data, user): + ''' + If a member claims they have an account on the old protospace portal, + go through and link their objects to their new user using the member_id + found with their email as a hint + ''' + old_members = old_models.Members.objects.using('old_portal') + + try: + old_member = old_members.get(email=data['email']) + except old_models.Members.DoesNotExist: + user.delete() + raise ValidationError(dict(email='Unable to find email in old database.')) + + member = models.Member.objects.get(id=old_member.id) + + if member.user: + raise ValidationError(dict(email='Old member already claimed.')) + + member.user = user + member.first_name = data['first_name'] + member.last_name = data['last_name'] + member.preferred_name = data['first_name'] + member.save() + + transactions = models.Transaction.objects.filter(member_id=member.id) + for t in transactions: + t.user = user + t.save() + + cards = models.Card.objects.filter(member_id=member.id) + for c in cards: + c.user = user + c.save() + + training = models.Training.objects.filter(member_id=member.id) + for t in training: + t.user = user + t.save() diff --git a/apiserver/apiserver/api/views.py b/apiserver/apiserver/api/views.py index 3a11ed1..cabe35a 100644 --- a/apiserver/apiserver/api/views.py +++ b/apiserver/apiserver/api/views.py @@ -1,4 +1,3 @@ -import datetime from django.contrib.auth.models import User, Group from django.shortcuts import get_object_or_404 from django.db.models import Max @@ -10,48 +9,20 @@ from rest_auth.views import PasswordChangeView from rest_auth.registration.views import RegisterView from fuzzywuzzy import fuzz, process from collections import OrderedDict +import datetime from . import models, serializers, utils +from .permissions import ( + is_admin_director, + AllowMetadata, + IsObjOwnerOrAdmin, + IsSessionInstructorOrAdmin, + ReadOnly, + IsAdminOrReadOnly, + IsInstructorOrReadOnly +) -class AllowMetadata(BasePermission): - def has_permission(self, request, view): - return request.method in ['OPTIONS', 'HEAD'] - -def is_admin_director(user): - return bool(user.is_staff or user.member.is_director or user.member.is_staff) - -class IsObjOwnerOrAdmin(BasePermission): - def has_object_permission(self, request, view, obj): - return bool(request.user and (obj.user == request.user or is_admin_director(request.user))) - -class IsSessionInstructorOrAdmin(BasePermission): - def has_object_permission(self, request, view, obj): - return bool(request.user and (obj.session.instructor == request.user or is_admin_director(request.user))) - -class ReadOnly(BasePermission): - def has_permission(self, request, view): - return bool(request.method in SAFE_METHODS) - def has_object_permission(self, request, view, obj): - return bool(request.method in SAFE_METHODS) - -class IsAdminOrReadOnly(BasePermission): - def has_permission(self, request, view): - return bool( - request.method in SAFE_METHODS or - request.user and - is_admin_director(request.user) - ) - -class IsInstructorOrReadOnly(BasePermission): - def has_permission(self, request, view): - return bool( - request.method in SAFE_METHODS or - request.user and - request.user.member.is_instructor - ) - - - +# define some shortcuts Base = viewsets.GenericViewSet List = mixins.ListModelMixin Retrieve = mixins.RetrieveModelMixin @@ -59,18 +30,9 @@ Create = mixins.CreateModelMixin Update = mixins.UpdateModelMixin Destroy = mixins.DestroyModelMixin - - -search_strings = {} -def gen_search_strings(): - for m in models.Member.objects.all(): - string = '{} {}'.format( - m.preferred_name, - m.last_name, - ).lower() - search_strings[string] = m.id - NUM_SEARCH_RESULTS = 10 + + class SearchViewSet(Base, Retrieve): permission_classes = [AllowMetadata | IsAuthenticated] @@ -84,11 +46,11 @@ class SearchViewSet(Base, Retrieve): queryset = models.Member.objects.all() search = self.request.data.get('q', '').lower() - if not search_strings: - gen_search_strings() # init cache + if not utils.search_strings: + utils.gen_search_strings() # init cache if len(search): - choices = search_strings.keys() + choices = utils.search_strings.keys() # get exact starts with matches results = [x for x in choices if x.startswith(search)] @@ -103,12 +65,12 @@ class SearchViewSet(Base, Retrieve): # remove dupes, truncate list results = list(OrderedDict.fromkeys(results))[:NUM_SEARCH_RESULTS] - result_ids = [search_strings[x] for x in results] + result_ids = [utils.search_strings[x] for x in results] result_objects = [queryset.get(id=x) for x in result_ids] queryset = result_objects elif self.action == 'create': - gen_search_strings() # update cache + utils.gen_search_strings() # update cache queryset = queryset.order_by('-vetted_date') return queryset @@ -200,6 +162,7 @@ class SessionViewSet(Base, List, Retrieve, Create, Update): def perform_create(self, serializer): serializer.save(instructor=self.request.user) + class TrainingViewSet(Base, Retrieve, Create, Update): permission_classes = [AllowMetadata | IsAuthenticated, IsObjOwnerOrAdmin | IsSessionInstructorOrAdmin | ReadOnly] serializer_class = serializers.TrainingSerializer @@ -281,5 +244,6 @@ class DoorViewSet(Base, List): class RegistrationView(RegisterView): serializer_class = serializers.RegistrationSerializer + class PasswordChangeView(PasswordChangeView): permission_classes = [AllowMetadata | IsAuthenticated]