"""
Rowing Machine Display Classifier
==================================

Binary classifier: 1 = rowing machine display, 0 = not rowing machine.

Two modes:
  1. Feature-based (works immediately, no training needed)
  2. CNN-based, trained from scratch on your own labelled images
     (needs training data)

Usage:
  # Predict with the feature-based classifier (no training needed)
  python classifier.py predict --image path/to/image.jpg

  # Classify every image in a directory and print a summary
  python classifier.py predict --dir path/to/images/

  # Organize training data (data/train/0/, data/train/1/, optional data/val/),
  # then train the CNN and predict with it
  python classifier.py train --data-dir data/
  python classifier.py predict --image path/to/image.jpg --model cnn

  # Print the extracted features of an image as JSON
  python classifier.py extract --image path/to/image.jpg
"""
|
|
|
|
import argparse
import functools
import json
import os
import sys
from pathlib import Path

import numpy as np
from PIL import Image, ImageFilter, ImageStat
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Feature-based classifier (works out of the box, no GPU needed)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def extract_features(image_path: str) -> dict:
    """Extract hand-crafted features that distinguish rowing displays.

    The image is decoded and converted to RGB plus a grayscale copy. Both are
    resized to 256x256, so every feature is computed on the same grid
    regardless of the input resolution.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        Dict of eight Python floats, keyed ``std_dev``, ``entropy``,
        ``edge_density``, ``h_line_strength``, ``mean_saturation``,
        ``dark_pixel_ratio``, ``bright_pixel_ratio`` and ``local_variance``.
        Intensity thresholds below are on the 0-255 grayscale scale.

    Raises:
        OSError: If the file is missing or cannot be decoded by PIL
            (``FileNotFoundError`` and ``PIL.UnidentifiedImageError`` both
            derive from it).
    """
    # Decode inside a context manager so the file handle is closed right away.
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    gray = img.convert("L")

    # Resize for consistent analysis
    gray_resized = gray.resize((256, 256))
    img_resized = img.resize((256, 256))

    pixels = np.array(gray_resized, dtype=np.float64)
    color_pixels = np.array(img_resized, dtype=np.float64)

    features = {}

    # 1. Contrast: rowing displays have high contrast (dark text on light bg)
    features["std_dev"] = float(np.std(pixels))

    # 2. Histogram entropy (32 bins): an image dominated by a few tones, such
    #    as text vs background, concentrates mass in few bins -> low entropy.
    hist, _ = np.histogram(pixels, bins=32, range=(0, 256))
    probs = hist / hist.sum()
    probs = probs[probs > 0]  # empty bins contribute 0 (0 * log 0 := 0)
    features["entropy"] = float(-np.sum(probs * np.log2(probs)))

    # 3. Edge density: fraction of pixels with a strong FIND_EDGES response.
    #    Text and numbers create lots of edges.
    edges = gray_resized.filter(ImageFilter.FIND_EDGES)
    edge_pixels = np.array(edges, dtype=np.float64)
    features["edge_density"] = float(np.mean(edge_pixels > 30))

    # 4. Horizontal line features: this Sobel kernel responds to vertical
    #    intensity changes, i.e. horizontal edges such as row separators.
    #    offset=128 lets negative responses survive the 0-255 "L" output.
    sobel_h = gray_resized.filter(
        ImageFilter.Kernel(
            size=(3, 3), kernel=[-1, -2, -1, 0, 0, 0, 1, 2, 1], scale=1, offset=128
        )
    )
    sobel_pixels = np.abs(np.array(sobel_h, dtype=np.float64) - 128)
    features["h_line_strength"] = float(np.mean(sobel_pixels > 20))

    # 5. Color saturation, (max - min) / max per pixel. Rowing displays are
    #    typically low-saturation.
    max_c = color_pixels.max(axis=2)
    min_c = color_pixels.min(axis=2)
    # Divide only where max_c > 0. Pure-black pixels keep saturation 0, and
    # numpy no longer emits divide-by-zero / invalid-value RuntimeWarnings.
    saturation = np.divide(
        max_c - min_c, max_c, out=np.zeros_like(max_c), where=max_c > 0
    )
    features["mean_saturation"] = float(np.mean(saturation))

    # 6. Dark pixel ratio: displays have significant dark regions (text)
    features["dark_pixel_ratio"] = float(np.mean(pixels < 80))

    # 7. Bright pixel ratio: displays have bright background regions
    features["bright_pixel_ratio"] = float(np.mean(pixels > 180))

    # 8. Texture via local variance: mean squared difference between the
    #    image and a Gaussian-blurred (radius 5) copy of itself
    blurred = np.array(
        gray_resized.filter(ImageFilter.GaussianBlur(5)), dtype=np.float64
    )
    features["local_variance"] = float(np.mean((pixels - blurred) ** 2))

    return features
|
|
|
|
|
|
def feature_based_predict(image_path: str, verbose: bool = False) -> tuple[int, float]:
    """
    Predict using hand-crafted features and a rule-based scorer.

    Each rule that fires awards a fixed number of points. The total, capped at
    100, is compared against a threshold of 45 points. Points are kept as
    integers so that a total of exactly 45 is not lost to floating-point error
    (with float weights, 0.15 + 0.15 + 0.15 == 0.44999999999999996 < 0.45).

    Args:
        image_path: Path of the image to classify.
        verbose: If True, print every extracted feature value.

    Returns:
        (label, confidence): label is 1 for a rowing display, 0 otherwise.
        confidence lies in [0.5, 1.0] and grows with the score's distance from
        the threshold. It is 0.5 exactly at the threshold, and 1.0 at 0 points
        (label 0) or 100 points (label 1).
    """
    feats = extract_features(image_path)

    if verbose:
        print("\n Feature values:")
        for k, v in feats.items():
            print(f" {k:>20s}: {v:.4f}")

    max_points = 100
    threshold = 45

    # (points, condition) for every rule
    rules = [
        # Rowing displays: high contrast
        (15, feats["std_dev"] > 50),
        (10, feats["std_dev"] > 70),
        # High edge density (text/numbers)
        (15, feats["edge_density"] > 0.08),
        (10, feats["edge_density"] > 0.15),
        # Horizontal lines (separators between rows of data)
        (10, feats["h_line_strength"] > 0.06),
        # Low saturation (monochrome-ish displays)
        (10, feats["mean_saturation"] < 0.15),
        # Low histogram entropy (few dominant tones: text vs background)
        (10, feats["entropy"] < 3.8),
        # Has both dark and bright regions
        (15, feats["dark_pixel_ratio"] > 0.15 and feats["bright_pixel_ratio"] > 0.15),
        # High local variance = structured content
        (10, feats["local_variance"] > 200),
    ]
    # The rules can add up to 105, hence the cap.
    points = min(sum(pts for pts, fired in rules if fired), max_points)

    label = 1 if points >= threshold else 0
    # Map the distance from the threshold onto [0.5, 1.0] on each side. The
    # reported confidence is then never below 0.5 for the predicted label and
    # is monotone in the score.
    if label == 1:
        confidence = 0.5 + 0.5 * (points - threshold) / (max_points - threshold)
    else:
        confidence = 0.5 + 0.5 * (threshold - points) / threshold

    return label, confidence
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CNN-based classifier (requires training)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def get_cnn_model():
    """Build a simple CNN for binary classification.

    Returns a freshly initialised ``RowingCNN`` with four stages. Each stage is
    convolution -> batch-norm -> ReLU, and the stages widen the channels
    3 -> 32 -> 64 -> 128 -> 256. The first three stages end in 2x2 max pooling
    and the last in an adaptive average pool to a fixed 4x4 grid. A two-layer
    head with dropout 0.5 then maps the flattened features to one raw logit
    per image (no sigmoid is applied).

    If PyTorch cannot be imported, an install hint is printed and the process
    exits with status 1.
    """
    try:
        import torch  # noqa: F401  (a missing install is reported here)
        import torch.nn as nn
    except ImportError:
        print(
            "Error: PyTorch required for CNN mode. Install with: pip install torch torchvision"
        )
        sys.exit(1)

    channel_plan = (3, 32, 64, 128, 256)
    pooled_side = 4

    class RowingCNN(nn.Module):
        def __init__(self):
            super().__init__()
            stages = []
            last_stage = len(channel_plan) - 2
            for idx, (c_in, c_out) in enumerate(zip(channel_plan, channel_plan[1:])):
                stages.extend(
                    [
                        nn.Conv2d(c_in, c_out, 3, padding=1),
                        nn.BatchNorm2d(c_out),
                        nn.ReLU(),
                    ]
                )
                if idx == last_stage:
                    stages.append(nn.AdaptiveAvgPool2d((pooled_side, pooled_side)))
                else:
                    stages.append(nn.MaxPool2d(2))
            self.features = nn.Sequential(*stages)

            flat_dim = channel_plan[-1] * pooled_side * pooled_side
            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Linear(flat_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1),
            )

        def forward(self, x):
            feature_map = self.features(x)
            return self.classifier(feature_map)

    return RowingCNN()
|
|
|
|
|
|
def _evaluate_accuracy(model, loader, device) -> float:
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    Switches the model to eval mode. A sample is predicted positive when
    sigmoid(logit) > 0.5, the same rule used for training accuracy.
    """
    import torch

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs).squeeze(1)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
    return correct / total


def train_cnn(
    data_dir: str,
    epochs: int = 20,
    lr: float = 1e-3,
    save_path: str = "screen_classifier_model.pth",
):
    """
    Train the CNN. Expects data_dir with structure:

        data_dir/
            train/
                0/    (non-rowing images)
                1/    (rowing images)
            val/      (optional, same structure)

    torchvision's ImageFolder numbers classes by sorted folder name, so the
    non-rowing folder must sort first (as "0" < "1" does).

    Args:
        data_dir: Root directory containing ``train/`` and optionally ``val/``.
        epochs: Number of training epochs (must be >= 1).
        lr: Initial Adam learning rate. StepLR multiplies it by 0.1 every
            7 epochs.
        save_path: Where the state_dict is written. This default differs from
            the "model.pth" default used by the CLI and ``cnn_predict``.

    With ``val/`` present, the checkpoint from the epoch with the best
    validation accuracy is kept. Otherwise the weights are overwritten after
    every epoch, leaving the last epoch's weights.

    Raises:
        ValueError: If ``epochs`` < 1, if ``train/`` does not hold exactly two
            class folders, or if ``val/`` maps folders to labels differently.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # ImageNet channel statistics; cnn_predict normalises with the same values.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform_train = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ]
    )
    transform_val = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dir = os.path.join(data_dir, "train")
    train_ds = datasets.ImageFolder(train_dir, transform=transform_train)
    # The network emits a single logit, so exactly two classes are required.
    if len(train_ds.classes) != 2:
        raise ValueError(
            f"Expected exactly 2 class folders in {train_dir}, found {train_ds.classes}"
        )
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
    print(f"Training samples: {len(train_ds)} Classes: {train_ds.classes}")

    val_loader = None
    val_dir = os.path.join(data_dir, "val")
    if os.path.isdir(val_dir):
        val_ds = datasets.ImageFolder(val_dir, transform=transform_val)
        if val_ds.class_to_idx != train_ds.class_to_idx:
            raise ValueError(
                f"Class folders differ between {train_dir} ({train_ds.class_to_idx}) "
                f"and {val_dir} ({val_ds.class_to_idx})"
            )
        val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)
        print(f"Validation samples: {len(val_ds)}")

    model = get_cnn_model().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Start below any reachable accuracy so the first validated epoch is always
    # checkpointed. Starting at 0.0, a run whose val accuracy never rose above
    # 0 would save nothing yet still report "Model saved".
    best_acc = -1.0
    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs).squeeze(1)  # (batch, 1) logits -> (batch,)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by batch size for an exact
            # epoch mean.
            running_loss += loss.item() * inputs.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)

        scheduler.step()
        train_acc = correct / total
        avg_loss = running_loss / total
        line = f" Epoch {epoch + 1:>3d}/{epochs} loss={avg_loss:.4f} train_acc={train_acc:.3f}"

        if val_loader is not None:
            val_acc = _evaluate_accuracy(model, val_loader, device)
            line += f" val_acc={val_acc:.3f}"
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), save_path)
                line += " *saved*"
        else:
            # No validation split: keep the most recent weights.
            torch.save(model.state_dict(), save_path)

        print(line)

    if val_loader is not None:
        print(f"\nModel saved to {save_path} (best val_acc={best_acc:.3f})")
    else:
        print(f"\nModel saved to {save_path}")
|
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_cnn(model_path: str, mtime_ns: int):
    """Build the CNN, load weights from ``model_path``, return (model, device, transform).

    Memoised so that classifying many images (``predict --dir``) builds the
    network and reads the checkpoint only once. ``mtime_ns`` serves only as
    part of the cache key. A checkpoint rewritten on disk (e.g. by retraining
    in the same process) is therefore reloaded instead of served stale.
    """
    import torch
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_cnn_model()
    # SECURITY: torch.load unpickles the file. weights_only=True limits that to
    # tensors and plain containers, which is all a state_dict needs, so a
    # tampered .pth cannot execute arbitrary code on load (torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    return model, device, transform


def cnn_predict(image_path: str, model_path: str = "model.pth") -> tuple[int, float]:
    """Predict using the trained CNN.

    Args:
        image_path: Path of the image to classify.
        model_path: state_dict file as written by ``train_cnn``.

    Returns:
        (label, confidence): label is 1 when sigmoid(logit) > 0.5, else 0.
        confidence is the probability assigned to the returned label, so it
        lies in [0.5, 1.0].

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        OSError: If the image cannot be opened or decoded.
    """
    import torch

    model_path = os.fspath(model_path)
    model, device, transform = _load_cnn(model_path, os.stat(model_path).st_mtime_ns)

    # Context manager releases the image file handle immediately.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logit = model(tensor).squeeze()
        prob = torch.sigmoid(logit).item()

    label = 1 if prob > 0.5 else 0
    confidence = prob if label == 1 else 1.0 - prob
    return label, confidence
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Lower-case file extensions accepted by `predict --dir`.
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"})


def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with the ``predict``, ``train`` and ``extract`` subcommands."""
    parser = argparse.ArgumentParser(description="Rowing Machine Display Classifier")
    sub = parser.add_subparsers(dest="command", required=True)

    # --- predict ---
    p_pred = sub.add_parser("predict", help="Classify an image or directory of images")
    p_pred.add_argument("--image", help="Path to image file")
    p_pred.add_argument("--dir", help="Path to directory of images")
    p_pred.add_argument(
        "--model",
        choices=["features", "cnn"],
        default="features",
        help="Which classifier to use (default: features)",
    )
    p_pred.add_argument("--model-path", default="model.pth", help="Path to CNN weights")
    p_pred.add_argument("--verbose", "-v", action="store_true")

    # --- train ---
    p_train = sub.add_parser("train", help="Train the CNN classifier")
    p_train.add_argument("--data-dir", required=True, help="Root data directory")
    p_train.add_argument("--epochs", type=int, default=20)
    p_train.add_argument("--lr", type=float, default=1e-3)
    p_train.add_argument("--save", default="model.pth", help="Where to save weights")

    # --- extract ---
    p_feat = sub.add_parser("extract", help="Print extracted features for an image")
    p_feat.add_argument("--image", required=True)

    return parser


def _collect_image_paths(
    args: argparse.Namespace, parser: argparse.ArgumentParser
) -> list[Path]:
    """Turn the ``predict`` options --image / --dir into the list of files to classify."""
    if args.dir:
        dir_path = Path(args.dir)
        if not dir_path.is_dir():
            print(f"Error: {args.dir} is not a directory", file=sys.stderr)
            sys.exit(1)
        # is_file() skips sub-directories that merely look like images ("x.jpg/").
        image_paths = sorted(
            p
            for p in dir_path.iterdir()
            if p.is_file() and p.suffix.lower() in _IMAGE_EXTS
        )
        if not image_paths:
            print(f"No images found in {args.dir}", file=sys.stderr)
            sys.exit(1)
        return image_paths

    image_path = Path(args.image)
    if not image_path.is_file():
        parser.error(f"image not found: {args.image}")
    return [image_path]


def _run_predict(args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
    """Classify the requested image(s); print one line per image and, for --dir, a summary."""
    if not args.image and not args.dir:
        parser.error("predict requires --image or --dir")
    if args.image and args.dir:
        parser.error("--image and --dir are mutually exclusive")
    # Fail once up front instead of once per image when the weights are missing.
    if args.model == "cnn" and not Path(args.model_path).is_file():
        parser.error(
            f"CNN weights not found: {args.model_path} "
            "(run 'train' first or pass --model-path)"
        )

    image_paths = _collect_image_paths(args, parser)

    rowing_count = 0
    failed_count = 0
    for img_path in image_paths:
        try:
            if args.model == "features":
                label, conf = feature_based_predict(str(img_path), verbose=args.verbose)
            else:
                label, conf = cnn_predict(str(img_path), args.model_path)
        except OSError as err:
            # Missing, corrupt or unsupported image file. A single --image run
            # fails cleanly; a --dir run reports the file and keeps going.
            if not args.dir:
                print(f"Error: cannot read {img_path}: {err}", file=sys.stderr)
                sys.exit(1)
            print(f" {img_path.name} \u2192 UNREADABLE ({err})", file=sys.stderr)
            failed_count += 1
            continue

        tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
        if args.dir:
            print(f" {img_path.name} \u2192 {tag} (confidence={conf:.2f})")
        else:
            print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
        if label == 1:
            rowing_count += 1

    # Summary for directory mode
    if args.dir:
        total = len(image_paths)
        not_rowing = total - failed_count - rowing_count
        summary = (
            f"\n Summary: {total} images | {rowing_count} rowing | {not_rowing} not rowing"
        )
        if failed_count:
            summary += f" | {failed_count} unreadable"
        print(summary)


def main():
    """Command-line entry point: parse arguments and run the chosen subcommand."""
    parser = _build_parser()
    args = parser.parse_args()

    if args.command == "predict":
        _run_predict(args, parser)
    elif args.command == "train":
        train_cnn(args.data_dir, epochs=args.epochs, lr=args.lr, save_path=args.save)
    elif args.command == "extract":
        if not Path(args.image).is_file():
            parser.error(f"image not found: {args.image}")
        feats = extract_features(args.image)
        print(json.dumps(feats, indent=2))


if __name__ == "__main__":
    main()
|