refactor: Update data directory structure to 'data/labelled'
This commit is contained in:
4
train.py
4
train.py
@@ -78,7 +78,7 @@ class GarageDoorCNN(nn.Module):
|
|||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
# --- Hyperparameters and Configuration ---
|
# --- Hyperparameters and Configuration ---
|
||||||
DATA_DIR = 'data'
|
DATA_DIR = 'data/labelled'
|
||||||
MODEL_SAVE_PATH = 'garage_door_cnn.pth'
|
MODEL_SAVE_PATH = 'garage_door_cnn.pth'
|
||||||
NUM_EPOCHS = 10
|
NUM_EPOCHS = 10
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
@@ -158,7 +158,7 @@ def train_model():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Check if data directory exists
|
# Check if data directory exists
|
||||||
if not os.path.isdir('data/open') or not os.path.isdir('data/closed'):
|
if not os.path.isdir('data/labelled/open') or not os.path.isdir('data/labelled/closed'):
|
||||||
print("Error: Data directories 'data/open' and 'data/closed' not found.")
|
print("Error: Data directories 'data/open' and 'data/closed' not found.")
|
||||||
print("Please create them and place your image snapshots inside.")
|
print("Please create them and place your image snapshots inside.")
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user