diff --git a/authserver/authserver/api/models.py b/authserver/authserver/api/models.py index 2b56cd5..815624f 100644 --- a/authserver/authserver/api/models.py +++ b/authserver/authserver/api/models.py @@ -10,6 +10,13 @@ class Category(models.Model): def __str__(self): return self.name +class Firmware(models.Model): + version = models.CharField(unique=True, max_length=4) + binary = models.FileField() + + def __str__(self): + return self.version + class Tool(models.Model): category = models.ForeignKey(Category, related_name='tools', on_delete=models.PROTECT) name = models.CharField(max_length=32) @@ -18,12 +25,13 @@ class Tool(models.Model): wiki_id = models.IntegerField() photo = models.ImageField(blank=True) mac = models.CharField(max_length=12) + firmware = models.ForeignKey(Firmware, blank=True, null=True, related_name='tools', on_delete=models.SET_NULL) def __str__(self): return self.name class Course(models.Model): - name = models.CharField(max_length=32) + name = models.CharField(max_length=64) slug = models.CharField(max_length=32, unique=True) tools = models.ManyToManyField(Tool, blank=True) @@ -34,6 +42,7 @@ class Profile(models.Model): user = models.OneToOneField(User, on_delete=models.CASCADE, editable=False) lockout_admin = models.BooleanField(default=False) courses = models.ManyToManyField(Course, blank=True) + selected_courses = models.BooleanField(default=False) def __str__(self): return self.user.username diff --git a/authserver/authserver/api/serializers.py b/authserver/authserver/api/serializers.py index 5990ab8..af520a8 100644 --- a/authserver/authserver/api/serializers.py +++ b/authserver/authserver/api/serializers.py @@ -3,6 +3,8 @@ from rest_framework import serializers from . import models +from authserver.settings import FIRMWARE_VERSION_MAGIC + class CategorySerializer(serializers.HyperlinkedModelSerializer): url = serializers.HyperlinkedIdentityField(view_name='category-detail', lookup_field='slug') @@ -13,6 +15,7 @@ class CategorySerializer(serializers.HyperlinkedModelSerializer): class CourseSerializer(serializers.HyperlinkedModelSerializer): url = serializers.HyperlinkedIdentityField(view_name='course-detail', lookup_field='slug') tools = serializers.SlugRelatedField( + allow_null=True, many=True, slug_field='slug', queryset=models.Tool.objects.all() @@ -29,12 +32,17 @@ class ToolSerializer(serializers.HyperlinkedModelSerializer): lookup_field='slug', queryset=models.Category.objects.all() ) + firmware = serializers.SlugRelatedField( + allow_null=True, + slug_field='version', + queryset=models.Firmware.objects.all().order_by('-version') + ) class Meta: model = models.Tool fields = '__all__' -class ToolDataSerializer(serializers.HyperlinkedModelSerializer): +class CategoryToolSerializer(serializers.HyperlinkedModelSerializer): url = serializers.HyperlinkedIdentityField(view_name='category-detail', lookup_field='slug') tools = ToolSerializer(many=True) @@ -66,3 +74,31 @@ class UserSerializer(serializers.ModelSerializer): model = User fields = ('username', 'profile') depth = 1 + +class FirmwareSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='firmware-detail', lookup_field='version') + version = serializers.CharField(read_only=True) + tools = serializers.StringRelatedField(read_only=True, many=True) + + class Meta: + model = models.Firmware + fields = '__all__' + + def create(self, validated_data): + binary = validated_data['binary'].read().decode('ascii', 'replace') + + if binary.count(FIRMWARE_VERSION_MAGIC) != 2 or 'setup()' in binary: + raise serializers.ValidationError('Uploaded binary not a valid lockout firmware.') + + try: + binary_parts = binary.split(FIRMWARE_VERSION_MAGIC) + version = binary_parts[1].strip() + _ = int(version) + except: + raise serializers.ValidationError('Unable to extract firmware version.') + + if models.Firmware.objects.filter(version=version).exists(): + raise serializers.ValidationError('Firmware version already exists.') + + validated_data['version'] = version + return serializers.ModelSerializer.create(self, validated_data) diff --git a/authserver/authserver/api/views.py b/authserver/authserver/api/views.py index 50f678f..2908e8c 100644 --- a/authserver/authserver/api/views.py +++ b/authserver/authserver/api/views.py @@ -5,6 +5,8 @@ import struct import time from django.contrib.auth.models import User +from django.http import HttpResponse +from django.shortcuts import get_object_or_404, get_list_or_404 from rest_framework import mixins, permissions, status, viewsets from rest_framework.authtoken.models import Token @@ -12,7 +14,7 @@ from rest_framework.decorators import api_view, permission_classes from rest_framework.response import Response from . import models, serializers -from authserver.settings import PROTOSPACE_LOGIN_PAGE +from authserver.settings import PROTOSPACE_LOGIN_PAGE, FIRMWARE_VERSION_MAGIC LOG_DIRECTORY = '/var/log/pslockout' VALID_TIME = 1000000000 @@ -44,9 +46,11 @@ class ToolViewSet(viewsets.ModelViewSet): class ToolDataViewSet(viewsets.ViewSet): def list(self, request): - objects = models.Category.objects.all().order_by('id') - serializer = serializers.ToolDataSerializer(objects, many=True, context={'request': request}) - return Response({'categories': serializer.data}) + category_objects = models.Category.objects.all().order_by('id') + categories = serializers.CategoryToolSerializer(category_objects, many=True, context={'request': request}) + course_objects = models.Course.objects.all().order_by('id') + courses = serializers.CourseSerializer(course_objects, many=True, context={'request': request}) + return Response({'categories': categories.data, 'courses': courses.data}) class ProfileViewSet( mixins.RetrieveModelMixin, @@ -64,10 +68,17 @@ class UserViewSet(viewsets.ReadOnlyModelViewSet): def get_queryset(self): return User.objects.filter(username=self.request.user) -@api_view(["POST"]) +class FirmwareViewSet(viewsets.ModelViewSet): + queryset = models.Firmware.objects.all().order_by('-version') + serializer_class = serializers.FirmwareSerializer + permission_classes = (IsLockoutAdmin,) + lookup_field='version' + http_method_names = ['get', 'post', 'head', 'delete', 'options'] + +@api_view(['POST']) def login(request): - username = request.data.get("username").lower() - password = request.data.get("password") + username = request.data.get('username').lower() + password = request.data.get('password') if username is None or password is None: return Response({'error': 'Please provide both username and password'}, status=status.HTTP_400_BAD_REQUEST) @@ -90,14 +101,15 @@ def login(request): return Response({'token': token.key}, status=status.HTTP_200_OK) -@api_view(["GET"]) +@api_view(['GET']) def cards(request, mac): - cards = models.Card.objects.all().filter(profile__courses__tools__mac=mac) + tool = get_object_or_404(models.Tool, mac=mac) + cards = models.Card.objects.all().filter(profile__courses__tools=tool) card_numbers = [card.number for card in cards] return Response(','.join(card_numbers), status=status.HTTP_200_OK) -@api_view(["PUT"]) +@api_view(['PUT']) @permission_classes((IsLockoutAdmin,)) def update_cards(request): data = request.data @@ -124,7 +136,7 @@ def update_cards(request): return Response({'updated': updated_count}, status=status.HTTP_200_OK) EVENTS = [ - 'LOG_BOOT_UP - Booted up =============================================', + 'LOG_BOOT_UP - =========== Booted up, version: ', 'LOG_INIT_COMPLETE - Initialization completed', 'LOG_WIFI_CONNECTED - Wifi connected', 'LOG_WIFI_DISCONNECTED - Wifi disconnected', @@ -142,14 +154,17 @@ EVENTS = [ 'LOG_CARD_GOOD_READ - Successful read from card: ', 'LOG_CARD_ACCEPTED - Accepted card: ', 'LOG_CARD_DENIED - Denied card: ', + 'LOG_UPDATE_FAILED - Firmware update failed, code: ', ] -@api_view(["POST"]) +@api_view(['POST']) def infolog(request, mac): entries_processed = 0 oldest_valid_log_time = time.time() + tool = get_object_or_404(models.Tool, mac=mac) + encoded_log = request.data.get('log') if encoded_log: decoded_log = base64.b64decode(encoded_log) @@ -178,9 +193,43 @@ def infolog(request, mac): entries_processed += 1 log_file.write(entry_string + '\n') + version = str(get_object_or_404(models.Firmware, tools=tool)) + version_string = '{} {} {}'.format(FIRMWARE_VERSION_MAGIC, version, FIRMWARE_VERSION_MAGIC) + response_object = { - 'unixTime': int(time.time()), 'processed': entries_processed, + 'unixTime': int(time.time()), + 'version': version_string, } return Response(response_object, status=status.HTTP_200_OK) + +@api_view(['GET']) +def update(request, mac): + tool = get_object_or_404(models.Tool, mac=mac) + firmware = get_object_or_404(models.Firmware, tools=tool) + + response = HttpResponse(firmware.binary, content_type='text/plain') + response['Content-Disposition'] = 'attachment; filename=firmware_{}.bin'.format(firmware.version) + return response + +@api_view(['PUT']) +@permission_classes((permissions.IsAuthenticated,)) +def select_courses(request): + courses = request.data.get('courses') + if courses is None: + return Response({'error': 'Please provide a list of course slugs'}, + status=status.HTTP_400_BAD_REQUEST) + + profile = get_object_or_404(models.Profile, user=request.user) + + if profile.courses.count() or profile.selected_courses: + return Response({'error': 'Please provide a list of course slugs'}, + status=status.HTTP_400_BAD_REQUEST) + + course_objects = get_list_or_404(models.Course, slug__in=courses) + profile.courses.set(course_objects) + profile.selected_courses = True + profile.save() + + return Response({'updated': len(courses)}, status=status.HTTP_200_OK) diff --git a/authserver/authserver/settings.py b/authserver/authserver/settings.py index e6e3586..eb8c463 100644 --- a/authserver/authserver/settings.py +++ b/authserver/authserver/settings.py @@ -128,3 +128,5 @@ MEDIA_ROOT = os.path.join(BASE_DIR, 'media') MEDIA_URL = '/media/' PROTOSPACE_LOGIN_PAGE = 'https://my.protospace.ca/login' + +FIRMWARE_VERSION_MAGIC = 'MRWIZARD' diff --git a/authserver/authserver/urls.py b/authserver/authserver/urls.py index 2e30428..5644268 100644 --- a/authserver/authserver/urls.py +++ b/authserver/authserver/urls.py @@ -29,6 +29,7 @@ router.register(r'course', views.CourseViewSet, 'course') router.register(r'tooldata', views.ToolDataViewSet, 'tooldata') router.register(r'profile', views.ProfileViewSet) router.register(r'user', views.UserViewSet, 'user') +router.register(r'firmware', views.FirmwareViewSet, 'firmware') urlpatterns = [ url(r'^', include(router.urls)), @@ -38,6 +39,8 @@ urlpatterns = [ url(r'^cards/(?P.*)/', views.cards), url(r'^update-cards/', views.update_cards), url(r'^infolog/(?P.*)/', views.infolog), + url(r'^update/(?P.*)/', views.update), + url(r'^select-courses/', views.select_courses), ] if settings.DEBUG is True: