refactor: centralize transform constants in model.py

Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
2025-07-31 16:52:23 -06:00
parent 078c893770
commit c1d11ea3f7
3 changed files with 11 additions and 10 deletions

View File

@@ -1,6 +1,13 @@
import torch.nn as nn
from PIL import Image, ImageDraw
# --- Model and Transform Constants ---
# For the custom crop transform. User can adjust these.
TRIANGLE_CROP_WIDTH = 556
TRIANGLE_CROP_HEIGHT = 1184
RESIZE_DIM = 64 # Resize cropped image to this dimension (square)
# Custom transform to crop a triangle from the lower right corner
class CropLowerRightTriangle(object):
"""

View File

@@ -5,7 +5,8 @@ from PIL import Image
import os
import shutil
from model import CropLowerRightTriangle, GarageDoorCNN
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
def sort_images():
# --- Configuration ---
@@ -13,10 +14,6 @@ def sort_images():
SOURCE_DIR = 'data/hourly_photos/'
DEST_DIR = 'data/sorted/open/'
# These must match the parameters used during training
TRIANGLE_CROP_WIDTH = 556
TRIANGLE_CROP_HEIGHT = 1184
RESIZE_DIM = 64
# The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
CLASS_NAMES = ['closed', 'open']

View File

@@ -6,7 +6,8 @@ from torchvision import datasets, transforms
from PIL import Image, ImageDraw
import os
from model import CropLowerRightTriangle, GarageDoorCNN
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
def train_model():
@@ -16,10 +17,6 @@ def train_model():
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.001
# For the custom crop transform. User can adjust these.
TRIANGLE_CROP_WIDTH = 556
TRIANGLE_CROP_HEIGHT = 1184
RESIZE_DIM = 64 # Resize cropped image to this dimension (square)
# --- Data Preparation ---
# Define transforms