From d3a4bd7ce9d2a595761d1a96bf92d0af38bcdd39 Mon Sep 17 00:00:00 2001 From: Tanner Collin Date: Thu, 31 Jul 2025 16:55:32 -0600 Subject: [PATCH] feat: add --check-crop flag to verify image cropping Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) --- train.py | 47 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index a57c0f7..15a73b7 100644 --- a/train.py +++ b/train.py @@ -5,11 +5,41 @@ from torch.utils.data import DataLoader, random_split, WeightedRandomSampler from torchvision import datasets, transforms from PIL import Image, ImageDraw import os +import argparse from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH, TRIANGLE_CROP_HEIGHT, RESIZE_DIM) +def check_crop(): + """Saves a sample cropped image for debugging purposes and exits.""" + # Find a sample image from your dataset + SAMPLE_IMAGE_DIR = 'data/labelled/open' + if not os.path.isdir(SAMPLE_IMAGE_DIR) or not os.listdir(SAMPLE_IMAGE_DIR): + print(f"Error: Cannot find sample image in '{SAMPLE_IMAGE_DIR}'.") + print("Please ensure the directory exists and contains images.") + return + + sample_image_name = os.listdir(SAMPLE_IMAGE_DIR)[0] + sample_image_path = os.path.join(SAMPLE_IMAGE_DIR, sample_image_name) + + print(f"Creating debug crop from image: {sample_image_path}") + + # Load the image + image = Image.open(sample_image_path) + + # Create the transform + cropper = CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT) + + # Apply the transform + cropped_image = cropper(image) + + # Save the result + output_path = "cropped_debug_output.png" + cropped_image.save(output_path) + print(f"Debug image saved to '{output_path}'.") + + def train_model(): # --- Hyperparameters and Configuration --- DATA_DIR = 'data/labelled' @@ -104,9 +134,16 @@ def train_model(): print(f"Model saved to {MODEL_SAVE_PATH}") if __name__ == '__main__': - # Check if data directory exists - if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'): - print("Error: Data directories 'data/open' and 'data/closed' not found.") - print("Please create them and place your image snapshots inside.") + parser = argparse.ArgumentParser(description="Train a CNN for garage door detection or check the image crop.") + parser.add_argument('--check-crop', action='store_true', help='Save a sample cropped image and exit.') + args = parser.parse_args() + + if args.check_crop: + check_crop() else: - train_model() + # Check if data directory exists + if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'): + print("Error: Data directories 'data/open' and 'data/closed' not found.") + print("Please create them and place your image snapshots inside.") + else: + train_model()