Crops to rowing machine screen - can be trained with optimize_crop.py and screen_classifier
This commit is contained in:
@@ -195,7 +195,10 @@ def get_cnn_model():
|
||||
|
||||
|
||||
def train_cnn(
|
||||
data_dir: str, epochs: int = 20, lr: float = 1e-3, save_path: str = "model.pth"
|
||||
data_dir: str,
|
||||
epochs: int = 20,
|
||||
lr: float = 1e-3,
|
||||
save_path: str = "screen_classifier_model.pth",
|
||||
):
|
||||
"""
|
||||
Train the CNN. Expects data_dir with structure:
|
||||
@@ -336,8 +339,9 @@ def main():
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# --- predict ---
|
||||
p_pred = sub.add_parser("predict", help="Classify an image")
|
||||
p_pred.add_argument("--image", required=True, help="Path to image file")
|
||||
p_pred = sub.add_parser("predict", help="Classify an image or directory of images")
|
||||
p_pred.add_argument("--image", help="Path to image file")
|
||||
p_pred.add_argument("--dir", help="Path to directory of images")
|
||||
p_pred.add_argument(
|
||||
"--model",
|
||||
choices=["features", "cnn"],
|
||||
@@ -361,13 +365,50 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "predict":
|
||||
if args.model == "features":
|
||||
label, conf = feature_based_predict(args.image, verbose=args.verbose)
|
||||
else:
|
||||
label, conf = cnn_predict(args.image, args.model_path)
|
||||
if not args.image and not args.dir:
|
||||
parser.error("predict requires --image or --dir")
|
||||
if args.image and args.dir:
|
||||
parser.error("--image and --dir are mutually exclusive")
|
||||
|
||||
tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
|
||||
print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
|
||||
# Build list of image paths
|
||||
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff"}
|
||||
if args.dir:
|
||||
dir_path = Path(args.dir)
|
||||
if not dir_path.is_dir():
|
||||
print(f"Error: {args.dir} is not a directory")
|
||||
sys.exit(1)
|
||||
image_paths = sorted(
|
||||
p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
|
||||
)
|
||||
if not image_paths:
|
||||
print(f"No images found in {args.dir}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
image_paths = [Path(args.image)]
|
||||
|
||||
# Classify each image
|
||||
rowing_count = 0
|
||||
for img_path in image_paths:
|
||||
if args.model == "features":
|
||||
label, conf = feature_based_predict(str(img_path), verbose=args.verbose)
|
||||
else:
|
||||
label, conf = cnn_predict(str(img_path), args.model_path)
|
||||
|
||||
tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
|
||||
if args.dir:
|
||||
print(f" {img_path.name} \u2192 {tag} (confidence={conf:.2f})")
|
||||
else:
|
||||
print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
|
||||
if label == 1:
|
||||
rowing_count += 1
|
||||
|
||||
# Summary for directory mode
|
||||
if args.dir:
|
||||
total = len(image_paths)
|
||||
not_rowing = total - rowing_count
|
||||
print(
|
||||
f"\n Summary: {total} images | {rowing_count} rowing | {not_rowing} not rowing"
|
||||
)
|
||||
|
||||
elif args.command == "train":
|
||||
train_cnn(args.data_dir, epochs=args.epochs, lr=args.lr, save_path=args.save)
|
||||
|
||||
Reference in New Issue
Block a user