refactor: centralize transform constants in model.py
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
7
train.py
7
train.py
@@ -6,7 +6,8 @@ from torchvision import datasets, transforms
|
||||
from PIL import Image, ImageDraw
|
||||
import os
|
||||
|
||||
from model import CropLowerRightTriangle, GarageDoorCNN
|
||||
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
||||
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
||||
|
||||
|
||||
def train_model():
|
||||
@@ -16,10 +17,6 @@ def train_model():
|
||||
NUM_EPOCHS = 10
|
||||
BATCH_SIZE = 32
|
||||
LEARNING_RATE = 0.001
|
||||
# For the custom crop transform. User can adjust these.
|
||||
TRIANGLE_CROP_WIDTH = 556
|
||||
TRIANGLE_CROP_HEIGHT = 1184
|
||||
RESIZE_DIM = 64 # Resize cropped image to this dimension (square)
|
||||
|
||||
# --- Data Preparation ---
|
||||
# Define transforms
|
||||
|
Reference in New Issue
Block a user