diff --git a/caremyway/api/views.py b/caremyway/api/views.py index c19e66f..a037c44 100644 --- a/caremyway/api/views.py +++ b/caremyway/api/views.py @@ -103,14 +103,49 @@ class PriceViewSet(viewsets.ModelViewSet): instance.save() return Response(status=status.HTTP_204_NO_CONTENT) +def validate_param(value, field): + class ValidateParam(serializers.Serializer): + value = field + + if value == 'unspecified': + return value + + obj = ValidateParam(data={'value': value}) + if obj.is_valid(): + return obj.validated_data.get('value') + else: + raise serializers.ValidationError(obj.errors['value']) + +def shift_filter(get_self): + user = get_self.request.user + qp = get_self.request.query_params + completed = validate_param(qp.get('completed', 'unspecified'), serializers.BooleanField()) + approved = validate_param(qp.get('approved', 'unspecified'), serializers.NullBooleanField()) + manage = validate_param(qp.get('manage'), serializers.UUIDField(allow_null=True)) + work_type = validate_param(qp.get('work_type'), serializers.UUIDField(allow_null=True)) + + shifts = Shift.objects.filter(deleted=False) \ + .filter( + Q(price__management__client__user__username=user) + | Q(price__management__provider__user__username=user)) \ + .order_by('approved', 'set_start') + if completed is not 'unspecified': + shifts = shifts.exclude(actual_end__isnull=completed) + if approved is not 'unspecified': + shifts = shifts.filter(approved=approved) + if manage: + shifts = shifts.filter(price__management__uuid=manage) + if work_type: + shifts = shifts.filter(price__work_type__uuid=work_type) + + return shifts + class CShiftViewSet(viewsets.ModelViewSet): lookup_field = 'uuid' serializer_class = CShiftSerializer def get_queryset(self): - return Shift.objects.filter(deleted=False) \ - .filter(price__management__client__user__username=self.request.user) \ - .order_by('-set_start') + return shift_filter(self) def destroy(self, request, *args, **kwargs): instance = self.get_object() @@ -130,19 +165,7 @@ class PShiftViewSet(viewsets.ModelViewSet): http_method_names = ['get', 'head', 'put', 'options'] def get_queryset(self): - return Shift.objects.filter(deleted=False) \ - .filter(price__management__provider__user__username=self.request.user) \ - .order_by('-set_start') - -def validate_param(value, field): - class ValidateParam(serializers.Serializer): - value = field - - obj = ValidateParam(data={'value': value}) - if obj.is_valid(): - return obj.validated_data.get('value'); - else: - raise serializers.ValidationError(obj.errors['value']) + return shift_filter(self) def get_paystart(payday): # Assumes payday is a valided date obj