refactor: centralize transform constants in model.py
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
7
model.py
7
model.py
@@ -1,6 +1,13 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
# --- Model and Transform Constants ---
|
||||||
|
# For the custom crop transform. User can adjust these.
|
||||||
|
TRIANGLE_CROP_WIDTH = 556
|
||||||
|
TRIANGLE_CROP_HEIGHT = 1184
|
||||||
|
RESIZE_DIM = 64 # Resize cropped image to this dimension (square)
|
||||||
|
|
||||||
|
|
||||||
# Custom transform to crop a triangle from the lower right corner
|
# Custom transform to crop a triangle from the lower right corner
|
||||||
class CropLowerRightTriangle(object):
|
class CropLowerRightTriangle(object):
|
||||||
"""
|
"""
|
||||||
|
7
sort.py
7
sort.py
@@ -5,7 +5,8 @@ from PIL import Image
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from model import CropLowerRightTriangle, GarageDoorCNN
|
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
||||||
|
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
||||||
|
|
||||||
def sort_images():
|
def sort_images():
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
@@ -13,10 +14,6 @@ def sort_images():
|
|||||||
SOURCE_DIR = 'data/hourly_photos/'
|
SOURCE_DIR = 'data/hourly_photos/'
|
||||||
DEST_DIR = 'data/sorted/open/'
|
DEST_DIR = 'data/sorted/open/'
|
||||||
|
|
||||||
# These must match the parameters used during training
|
|
||||||
TRIANGLE_CROP_WIDTH = 556
|
|
||||||
TRIANGLE_CROP_HEIGHT = 1184
|
|
||||||
RESIZE_DIM = 64
|
|
||||||
|
|
||||||
# The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
|
# The classes are sorted alphabetically by ImageFolder: ['closed', 'open']
|
||||||
CLASS_NAMES = ['closed', 'open']
|
CLASS_NAMES = ['closed', 'open']
|
||||||
|
7
train.py
7
train.py
@@ -6,7 +6,8 @@ from torchvision import datasets, transforms
|
|||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from model import CropLowerRightTriangle, GarageDoorCNN
|
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
||||||
|
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
@@ -16,10 +17,6 @@ def train_model():
|
|||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 10
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
# For the custom crop transform. User can adjust these.
|
|
||||||
TRIANGLE_CROP_WIDTH = 556
|
|
||||||
TRIANGLE_CROP_HEIGHT = 1184
|
|
||||||
RESIZE_DIM = 64 # Resize cropped image to this dimension (square)
|
|
||||||
|
|
||||||
# --- Data Preparation ---
|
# --- Data Preparation ---
|
||||||
# Define transforms
|
# Define transforms
|
||||||
|
Reference in New Issue
Block a user