Crops to rowing machine screen - can be trained with optimize_crop.py and screen_classifier

This commit is contained in:
2026-03-16 13:46:02 +00:00
parent 2e386a4297
commit f0184319c6
4 changed files with 309 additions and 31 deletions

View File

@@ -195,7 +195,10 @@ def get_cnn_model():
def train_cnn(
data_dir: str, epochs: int = 20, lr: float = 1e-3, save_path: str = "model.pth"
data_dir: str,
epochs: int = 20,
lr: float = 1e-3,
save_path: str = "screen_classifier_model.pth",
):
"""
Train the CNN. Expects data_dir with structure:
@@ -336,8 +339,9 @@ def main():
sub = parser.add_subparsers(dest="command", required=True)
# --- predict ---
p_pred = sub.add_parser("predict", help="Classify an image")
p_pred.add_argument("--image", required=True, help="Path to image file")
p_pred = sub.add_parser("predict", help="Classify an image or directory of images")
p_pred.add_argument("--image", help="Path to image file")
p_pred.add_argument("--dir", help="Path to directory of images")
p_pred.add_argument(
"--model",
choices=["features", "cnn"],
@@ -361,13 +365,50 @@ def main():
args = parser.parse_args()
if args.command == "predict":
if args.model == "features":
label, conf = feature_based_predict(args.image, verbose=args.verbose)
else:
label, conf = cnn_predict(args.image, args.model_path)
if not args.image and not args.dir:
parser.error("predict requires --image or --dir")
if args.image and args.dir:
parser.error("--image and --dir are mutually exclusive")
tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
# Build list of image paths
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff"}
if args.dir:
dir_path = Path(args.dir)
if not dir_path.is_dir():
print(f"Error: {args.dir} is not a directory")
sys.exit(1)
image_paths = sorted(
p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
)
if not image_paths:
print(f"No images found in {args.dir}")
sys.exit(1)
else:
image_paths = [Path(args.image)]
# Classify each image
rowing_count = 0
for img_path in image_paths:
if args.model == "features":
label, conf = feature_based_predict(str(img_path), verbose=args.verbose)
else:
label, conf = cnn_predict(str(img_path), args.model_path)
tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
if args.dir:
print(f" {img_path.name} \u2192 {tag} (confidence={conf:.2f})")
else:
print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
if label == 1:
rowing_count += 1
# Summary for directory mode
if args.dir:
total = len(image_paths)
not_rowing = total - rowing_count
print(
f"\n Summary: {total} images | {rowing_count} rowing | {not_rowing} not rowing"
)
elif args.command == "train":
train_cnn(args.data_dir, epochs=args.epochs, lr=args.lr, save_path=args.save)