Files
rowing_stats/screen_classifier.py

382 lines
13 KiB
Python

"""
Rowing Machine Display Classifier
==================================
Binary classifier: 1 = rowing machine display, 0 = not rowing machine.
Two modes:
1. Feature-based (works immediately, no training needed)
2. CNN-based with transfer learning (needs training data)
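Usage:
# Predict with feature-based classifier (no training needed)
python classifier.py predict --image path/to/image.jpg
# Print the extracted feature values as JSON
python classifier.py extract --image path/to/image.jpg
# Organize training data as data/train/0 (not rowing) and data/train/1
# (rowing), optionally with data/val/ in the same layout; then train the CNN
python classifier.py train --data-dir data/
python classifier.py predict --image path/to/image.jpg --model cnn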
"""
import argparse
import os
import sys
import json
import numpy as np
from pathlib import Path
from PIL import Image, ImageStat, ImageFilter
# ---------------------------------------------------------------------------
# Feature-based classifier (works out of the box, no GPU needed)
# ---------------------------------------------------------------------------
def extract_features(image_path: str) -> dict:
    """Extract hand-crafted features that distinguish rowing displays.

    The image is converted to RGB (and from there to grayscale) and resized
    to 256x256, so feature values are comparable across input resolutions.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        Dict mapping feature name to a float: ``std_dev``, ``entropy``,
        ``edge_density``, ``h_line_strength``, ``mean_saturation``,
        ``dark_pixel_ratio``, ``bright_pixel_ratio`` and ``local_variance``.
    """
    # Release the file handle promptly; convert() returns an independent
    # copy, so the pixel data stays valid after the with-block exits.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    gray = img.convert("L")

    # Resize for consistent analysis
    gray_resized = gray.resize((256, 256))
    img_resized = img.resize((256, 256))
    pixels = np.array(gray_resized, dtype=np.float64)
    color_pixels = np.array(img_resized, dtype=np.float64)

    features = {}

    # 1. Contrast: rowing displays have high contrast (dark text on light bg)
    features["std_dev"] = float(np.std(pixels))

    # 2. Bimodality proxy: Shannon entropy (bits) of a 32-bin intensity
    # histogram; an image dominated by a few intensity clusters scores low.
    hist, _ = np.histogram(pixels, bins=32, range=(0, 256))
    hist_norm = hist / hist.sum()
    nonzero = hist_norm[hist_norm > 0]  # skip empty bins: log2(0) is undefined
    features["entropy"] = float(-np.sum(nonzero * np.log2(nonzero)))

    # 3. Edge density: fraction of pixels whose edge response exceeds 30
    edges = gray_resized.filter(ImageFilter.FIND_EDGES)
    edge_pixels = np.array(edges, dtype=np.float64)
    features["edge_density"] = float(np.mean(edge_pixels > 30))

    # 4. Horizontal line features: displays have horizontal separators.
    # This Sobel kernel takes a vertical gradient, so it responds to
    # horizontal edges. offset=128 keeps negative responses representable in
    # the 8-bit output; subtracting 128 recovers the (clipped) magnitude.
    sobel_h = gray_resized.filter(
        ImageFilter.Kernel(
            size=(3, 3), kernel=[-1, -2, -1, 0, 0, 0, 1, 2, 1], scale=1, offset=128
        )
    )
    sobel_pixels = np.abs(np.array(sobel_h, dtype=np.float64) - 128)
    features["h_line_strength"] = float(np.mean(sobel_pixels > 20))

    # 5. Color saturation, (max - min) / max per pixel: rowing displays are
    # typically low-saturation.
    max_c = color_pixels.max(axis=2)
    min_c = color_pixels.min(axis=2)
    # Divide only where max_c > 0. Black pixels keep saturation 0 without
    # evaluating 0/0, which np.where would do (and warn about) because it
    # computes both branches eagerly.
    saturation = np.divide(
        max_c - min_c, max_c, out=np.zeros_like(max_c), where=max_c > 0
    )
    features["mean_saturation"] = float(np.mean(saturation))

    # 6. Dark pixel ratio: displays have significant dark regions (text)
    features["dark_pixel_ratio"] = float(np.mean(pixels < 80))

    # 7. Bright pixel ratio: displays have bright background regions
    features["bright_pixel_ratio"] = float(np.mean(pixels > 180))

    # 8. Texture: mean squared difference between the image and a
    # Gaussian-blurred (radius 5) copy, i.e. energy of fine detail.
    blurred = np.array(
        gray_resized.filter(ImageFilter.GaussianBlur(5)), dtype=np.float64
    )
    features["local_variance"] = float(np.mean((pixels - blurred) ** 2))

    return features
def _rule_points(feats: dict) -> int:
    """Return the rule-based evidence for a rowing display, in integer points.

    Each satisfied rule adds a fixed number of points (hundredths of the
    final score); the maximum achievable total is 105. Integers are used so
    that threshold comparisons are exact.
    """
    points = 0
    # Rowing displays: high contrast
    if feats["std_dev"] > 50:
        points += 15
    if feats["std_dev"] > 70:
        points += 10
    # High edge density (text/numbers)
    if feats["edge_density"] > 0.08:
        points += 15
    if feats["edge_density"] > 0.15:
        points += 10
    # Horizontal lines (separators between rows of data)
    if feats["h_line_strength"] > 0.06:
        points += 10
    # Low saturation (monochrome-ish displays)
    if feats["mean_saturation"] < 0.15:
        points += 10
    # Low histogram entropy (text vs background clusters)
    if feats["entropy"] < 3.8:
        points += 10
    # Has both dark and bright regions
    if feats["dark_pixel_ratio"] > 0.15 and feats["bright_pixel_ratio"] > 0.15:
        points += 15
    # High local variance = structured content
    if feats["local_variance"] > 200:
        points += 10
    return points


def feature_based_predict(
    image_path: str, verbose: bool = False, *, threshold: float = 0.45
) -> tuple[int, float]:
    """
    Predict using hand-crafted features and a rule-based scorer.

    Args:
        image_path: Path to the image to classify.
        verbose: If True, print every extracted feature value.
        threshold: Minimum score in [0, 1] for label 1 (default 0.45).

    Returns:
        (label, confidence): label is 1 for a rowing display, else 0.
        Confidence is the score for label 1 and ``1 - score`` for label 0.
        It is not calibrated: a label-1 result just above the threshold
        reports a confidence below 0.5.
    """
    feats = extract_features(image_path)
    if verbose:
        print("\n Feature values:")
        for k, v in feats.items():
            print(f" {k:>20s}: {v:.4f}")

    # Summing integer points and dividing once avoids float accumulation
    # error: 0.15 + 0.15 + 0.15 == 0.44999999999999996, which used to fall
    # just below the 0.45 threshold, whereas 45 / 100 == 0.45.
    score = min(_rule_points(feats), 100) / 100
    label = 1 if score >= threshold else 0
    confidence = score if label == 1 else 1.0 - score
    return label, confidence
# ---------------------------------------------------------------------------
# CNN-based classifier (requires training)
# ---------------------------------------------------------------------------
def get_cnn_model():
    """Build an untrained CNN for binary rowing-display classification.

    Four conv blocks (32/64/128/256 channels, each conv followed by
    BatchNorm and ReLU, the first three by 2x max-pooling) feed an adaptive
    average pool to 4x4 and a small MLP head. The head ends in a single
    logit; apply a sigmoid to get the probability of class 1. In this file
    the model is fed 224x224 normalized RGB tensors.

    Returns:
        A fresh ``torch.nn.Module`` mapping ``(N, 3, H, W)`` inputs to
        ``(N, 1)`` logits.

    Raises:
        ImportError: if PyTorch is not installed.
    """
    try:
        import torch.nn as nn
    except ImportError as err:
        # Raise rather than sys.exit(): this is library code, so the caller
        # (or the CLI) decides how to report a missing dependency.
        raise ImportError(
            "PyTorch required for CNN mode. Install with: pip install torch torchvision"
        ) from err

    class RowingCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(128, 256, 3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                # Fixed 4x4 output so the Linear layer size is independent
                # of the input resolution.
                nn.AdaptiveAvgPool2d((4, 4)),
            )
            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Linear(256 * 4 * 4, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1),  # single logit, paired with BCEWithLogitsLoss
            )

        def forward(self, x):
            return self.classifier(self.features(x))

    return RowingCNN()
def _evaluate_binary(model, loader, device) -> float:
    """Return the accuracy of a single-logit binary *model* over *loader*."""
    import torch

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs).squeeze(1)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def train_cnn(
    data_dir: str,
    epochs: int = 20,
    lr: float = 1e-3,
    save_path: str = "model.pth",
    *,
    batch_size: int = 16,
    num_workers: int = 2,
):
    """
    Train the CNN and save its ``state_dict`` to *save_path*.

    Expects data_dir with structure:
        data_dir/
            train/
                0/  (non-rowing images)
                1/  (rowing images)
            val/    (optional, same structure)

    ImageFolder assigns class indices in sorted folder-name order, so the
    folder that sorts first becomes label 0. Exactly two class folders are
    required because the model emits one logit trained with BCE.

    If ``val/`` exists, the checkpoint with the best validation accuracy is
    kept, and the first epoch is always saved. Otherwise the weights are
    overwritten after every epoch, so the file holds the final model.

    Raises:
        ValueError: if ``epochs < 1``, if ``train/`` does not contain exactly
            two class folders, or if ``val/`` uses a different class mapping.
    """
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # NOTE(review): horizontal flips mirror digits/text on the displays;
    # confirm this augmentation helps rather than hurts.
    transform_train = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    transform_val = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    train_dir = os.path.join(data_dir, "train")
    train_ds = datasets.ImageFolder(train_dir, transform=transform_train)
    # A third class would produce targets >= 2, which BCE accepts silently.
    if len(train_ds.classes) != 2:
        raise ValueError(
            f"Expected exactly 2 class folders in {train_dir}, found {train_ds.classes}"
        )
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    print(f"Training samples: {len(train_ds)} Classes: {train_ds.classes}")

    val_loader = None
    val_dir = os.path.join(data_dir, "val")
    if os.path.isdir(val_dir):
        val_ds = datasets.ImageFolder(val_dir, transform=transform_val)
        if val_ds.class_to_idx != train_ds.class_to_idx:
            raise ValueError(
                f"Class folders in {val_dir} ({val_ds.classes}) do not match "
                f"{train_dir} ({train_ds.classes})"
            )
        val_loader = DataLoader(
            val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )
        print(f"Validation samples: {len(val_ds)}")

    model = get_cnn_model().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Start below any achievable accuracy so the first validated epoch is
    # always checkpointed. With 0.0, a run whose val accuracy never rose
    # above zero never wrote save_path at all.
    best_acc = -1.0
    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
        scheduler.step()  # StepLR is stepped once per epoch

        train_acc = correct / total
        avg_loss = running_loss / total
        line = f" Epoch {epoch + 1:>3d}/{epochs} loss={avg_loss:.4f} train_acc={train_acc:.3f}"

        if val_loader is not None:
            val_acc = _evaluate_binary(model, val_loader, device)
            line += f" val_acc={val_acc:.3f}"
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), save_path)
                line += " *saved*"
        else:
            torch.save(model.state_dict(), save_path)
        print(line)

    print(f"\nModel saved to {save_path}")
def cnn_predict(image_path: str, model_path: str = "model.pth") -> tuple[int, float]:
    """Predict using the trained CNN.

    Args:
        image_path: Path to the image to classify.
        model_path: Path to a ``state_dict`` saved by ``train_cnn``.

    Returns:
        (label, confidence): label is 1 when sigmoid(logit) > 0.5, else 0.
        Confidence is the sigmoid probability of the predicted label.
    """
    import torch
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_cnn_model()
    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # code from a malicious file. Only load weights from trusted sources; on
    # PyTorch versions that support it, consider weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Must match the validation transform used in train_cnn.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # Release the file handle; convert() returns an independent copy.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        logit = model(tensor).squeeze()
    prob = torch.sigmoid(logit).item()

    label = 1 if prob > 0.5 else 0
    confidence = prob if label == 1 else 1.0 - prob
    return label, confidence
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main(argv: "list[str] | None" = None):
    """Parse command-line arguments and run predict / train / extract.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Missing input paths and a non-positive ``--epochs`` are reported through
    ``parser.error``, which prints usage and exits with status 2. Without
    this, they surfaced as tracebacks from deep inside PIL or torchvision.
    """
    parser = argparse.ArgumentParser(description="Rowing Machine Display Classifier")
    sub = parser.add_subparsers(dest="command", required=True)

    # --- predict ---
    p_pred = sub.add_parser("predict", help="Classify an image")
    p_pred.add_argument("--image", required=True, help="Path to image file")
    p_pred.add_argument(
        "--model",
        choices=["features", "cnn"],
        default="features",
        help="Which classifier to use (default: features)",
    )
    p_pred.add_argument("--model-path", default="model.pth", help="Path to CNN weights")
    p_pred.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Print extracted feature values (features model only)",
    )

    # --- train ---
    p_train = sub.add_parser("train", help="Train the CNN classifier")
    p_train.add_argument("--data-dir", required=True, help="Root data directory")
    p_train.add_argument("--epochs", type=int, default=20)
    p_train.add_argument("--lr", type=float, default=1e-3)
    p_train.add_argument("--save", default="model.pth", help="Where to save weights")

    # --- extract ---
    p_feat = sub.add_parser("extract", help="Print extracted features for an image")
    p_feat.add_argument("--image", required=True)

    args = parser.parse_args(argv)

    # Validate inputs up front so users get a usage error, not a traceback.
    if args.command in ("predict", "extract") and not os.path.isfile(args.image):
        parser.error(f"image not found: {args.image}")

    if args.command == "predict":
        if args.model == "features":
            label, conf = feature_based_predict(args.image, verbose=args.verbose)
        else:
            if not os.path.isfile(args.model_path):
                parser.error(
                    f"model weights not found: {args.model_path} "
                    "(train first with the 'train' command)"
                )
            label, conf = cnn_predict(args.image, args.model_path)
        tag = "ROWING MACHINE" if label == 1 else "NOT ROWING MACHINE"
        print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
    elif args.command == "train":
        train_dir = os.path.join(args.data_dir, "train")
        if not os.path.isdir(train_dir):
            parser.error(f"training directory not found: {train_dir}")
        if args.epochs < 1:
            # Zero epochs would save nothing yet still report a saved model.
            parser.error("--epochs must be at least 1")
        train_cnn(args.data_dir, epochs=args.epochs, lr=args.lr, save_path=args.save)
    elif args.command == "extract":
        feats = extract_features(args.image)
        print(json.dumps(feats, indent=2))


if __name__ == "__main__":
    main()