diff --git a/model.py b/model.py index 3da6be8..86d563f 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,13 @@ import torch.nn as nn from PIL import Image, ImageDraw +# --- Model and Transform Constants --- +# For the custom crop transform. User can adjust these. +TRIANGLE_CROP_WIDTH = 556 +TRIANGLE_CROP_HEIGHT = 1184 +RESIZE_DIM = 64 # Resize cropped image to this dimension (square) + + # Custom transform to crop a triangle from the lower right corner class CropLowerRightTriangle(object): """ diff --git a/sort.py b/sort.py index c42ba4c..47c46ab 100644 --- a/sort.py +++ b/sort.py @@ -5,7 +5,8 @@ from PIL import Image import os import shutil -from model import CropLowerRightTriangle, GarageDoorCNN +from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH, + TRIANGLE_CROP_HEIGHT, RESIZE_DIM) def sort_images(): # --- Configuration --- @@ -13,10 +14,6 @@ def sort_images(): SOURCE_DIR = 'data/hourly_photos/' DEST_DIR = 'data/sorted/open/' - # These must match the parameters used during training - TRIANGLE_CROP_WIDTH = 556 - TRIANGLE_CROP_HEIGHT = 1184 - RESIZE_DIM = 64 # The classes are sorted alphabetically by ImageFolder: ['closed', 'open'] CLASS_NAMES = ['closed', 'open'] diff --git a/train.py b/train.py index a42caa7..a57c0f7 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,8 @@ from torchvision import datasets, transforms from PIL import Image, ImageDraw import os -from model import CropLowerRightTriangle, GarageDoorCNN +from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH, + TRIANGLE_CROP_HEIGHT, RESIZE_DIM) def train_model(): @@ -16,10 +17,6 @@ def train_model(): NUM_EPOCHS = 10 BATCH_SIZE = 32 LEARNING_RATE = 0.001 - # For the custom crop transform. User can adjust these. - TRIANGLE_CROP_WIDTH = 556 - TRIANGLE_CROP_HEIGHT = 1184 - RESIZE_DIM = 64 # Resize cropped image to this dimension (square) # --- Data Preparation --- # Define transforms