feat: add --check-crop flag to verify image cropping
Co-authored-by: aider (gemini/gemini-2.5-pro-preview-05-06) <aider@aider.chat>
This commit is contained in:
37
train.py
37
train.py
@@ -5,11 +5,41 @@ from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
|
|||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
import os
|
import os
|
||||||
|
import argparse
|
||||||
|
|
||||||
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
from model import (CropLowerRightTriangle, GarageDoorCNN, TRIANGLE_CROP_WIDTH,
|
||||||
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
TRIANGLE_CROP_HEIGHT, RESIZE_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
def check_crop():
|
||||||
|
"""Saves a sample cropped image for debugging purposes and exits."""
|
||||||
|
# Find a sample image from your dataset
|
||||||
|
SAMPLE_IMAGE_DIR = 'data/labelled/open'
|
||||||
|
if not os.path.isdir(SAMPLE_IMAGE_DIR) or not os.listdir(SAMPLE_IMAGE_DIR):
|
||||||
|
print(f"Error: Cannot find sample image in '{SAMPLE_IMAGE_DIR}'.")
|
||||||
|
print("Please ensure the directory exists and contains images.")
|
||||||
|
return
|
||||||
|
|
||||||
|
sample_image_name = os.listdir(SAMPLE_IMAGE_DIR)[0]
|
||||||
|
sample_image_path = os.path.join(SAMPLE_IMAGE_DIR, sample_image_name)
|
||||||
|
|
||||||
|
print(f"Creating debug crop from image: {sample_image_path}")
|
||||||
|
|
||||||
|
# Load the image
|
||||||
|
image = Image.open(sample_image_path)
|
||||||
|
|
||||||
|
# Create the transform
|
||||||
|
cropper = CropLowerRightTriangle(triangle_width=TRIANGLE_CROP_WIDTH, triangle_height=TRIANGLE_CROP_HEIGHT)
|
||||||
|
|
||||||
|
# Apply the transform
|
||||||
|
cropped_image = cropper(image)
|
||||||
|
|
||||||
|
# Save the result
|
||||||
|
output_path = "cropped_debug_output.png"
|
||||||
|
cropped_image.save(output_path)
|
||||||
|
print(f"Debug image saved to '{output_path}'.")
|
||||||
|
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
# --- Hyperparameters and Configuration ---
|
# --- Hyperparameters and Configuration ---
|
||||||
DATA_DIR = 'data/labelled'
|
DATA_DIR = 'data/labelled'
|
||||||
@@ -104,6 +134,13 @@ def train_model():
|
|||||||
print(f"Model saved to {MODEL_SAVE_PATH}")
|
print(f"Model saved to {MODEL_SAVE_PATH}")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description="Train a CNN for garage door detection or check the image crop.")
|
||||||
|
parser.add_argument('--check-crop', action='store_true', help='Save a sample cropped image and exit.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.check_crop:
|
||||||
|
check_crop()
|
||||||
|
else:
|
||||||
# Check if data directory exists
|
# Check if data directory exists
|
||||||
if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'):
|
if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'):
|
||||||
print("Error: Data directories 'data/open' and 'data/closed' not found.")
|
print("Error: Data directories 'data/open' and 'data/closed' not found.")
|
||||||
|
Reference in New Issue
Block a user