feat: add --check-crop flag to verify image cropping

Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
2025-07-31 16:55:32 -06:00
parent c1d11ea3f7
commit d3a4bd7ce9

View File

@@ -5,11 +5,41 @@ from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms from torchvision import datasets, transforms
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import os import os
import argparse
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH, from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
TRIANGLE_CROP_HEIGHT, RESIZE_DIM) TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
def check_crop():
"""Saves a sample cropped image for debugging purposes and exits."""
# Find a sample image from your dataset
SAMPLE_IMAGE_DIR = 'data/labelled/open'
if not os.path.isdir(SAMPLE_IMAGE_DIR) or not os.listdir(SAMPLE_IMAGE_DIR):
print(f"Error: Cannot find sample image in '{SAMPLE_IMAGE_DIR}'.")
print("Please ensure the directory exists and contains images.")
return
sample_image_name = os.listdir(SAMPLE_IMAGE_DIR)[0]
sample_image_path = os.path.join(SAMPLE_IMAGE_DIR, sample_image_name)
print(f"Creating debug crop from image: {sample_image_path}")
# Load the image
image = Image.open(sample_image_path)
# Create the transform
cropper = CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT)
# Apply the transform
cropped_image = cropper(image)
# Save the result
output_path = "cropped_debug_output.png"
cropped_image.save(output_path)
print(f"Debug image saved to '{output_path}'.")
def train_model(): def train_model():
# --- Hyperparameters and Configuration --- # --- Hyperparameters and Configuration ---
DATA_DIR = 'data/labelled' DATA_DIR = 'data/labelled'
@@ -104,9 +134,16 @@ def train_model():
print(f"Model saved to {MODEL_SAVE_PATH}") print(f"Model saved to {MODEL_SAVE_PATH}")
if __name__ == '__main__': if __name__ == '__main__':
# Check if data directory exists parser = argparse.ArgumentParser(description="Train a CNN for garage door detection or check the image crop.")
if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'): parser.add_argument('--check-crop', action='store_true', help='Save a sample cropped image and exit.')
print("Error: Data directories 'data/open' and 'data/closed' not found.") args = parser.parse_args()
print("Please create them and place your image snapshots inside.")
if args.check_crop:
check_crop()
else: else:
train_model() # Check if data directory exists
if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'):
print("Error: Data directories 'data/open' and 'data/closed' not found.")
print("Please create them and place your image snapshots inside.")
else:
train_model()