Make new file for each step of processing
This commit is contained in:
155
crop_to_screen.py
Normal file
155
crop_to_screen.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Crop Concept 2 PM5 rowing machine screens from photos using OpenCV.
|
||||
|
||||
Detection strategy:
|
||||
The LCD screen has HIGH internal edge density (text/numbers/lines)
|
||||
compared to other bright regions (windows, walls, lockers).
|
||||
We threshold at multiple brightness levels, filter by edge density,
|
||||
aspect ratio, and size, then pick the best match.
|
||||
|
||||
Usage:
|
||||
python crop_to_screen.py [input_dir] [output_dir]
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
import sys
|
||||
|
||||
|
||||
def find_screen(image):
    """
    Detect the Concept 2 PM5 LCD screen region in the image.

    The grayscale image is binarised at several brightness thresholds.
    Every bright blob found is filtered by relative size, aspect ratio,
    rectangularity and the density of Canny edges inside its bounding box;
    the surviving blob with the highest score is returned.

    Args:
        image: BGR image array (it is converted with cv2.COLOR_BGR2GRAY).

    Returns:
        (x, y, w, h) bounding box or None if not found.
    """
    h_img, w_img = image.shape[:2]
    img_area = h_img * w_img
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Pre-compute edge map for internal-content scoring
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, 50, 150)

    # The structuring element is the same for every threshold, so it is
    # built once instead of once per sweep iteration.
    kern = cv2.getStructuringElement(cv2.MORPH_RECT, (11, 11))

    # Best candidate seen so far as (score, x, y, w, h).  Tracking the
    # running maximum replaces collecting and sorting every candidate; the
    # strict ">" keeps the earliest candidate on ties, exactly like taking
    # the first element of a stable descending sort.
    best = None

    # Sweep binary thresholds 120, 130, ..., 190 because the screen's
    # brightness varies with the lighting conditions.
    for thresh_val in range(120, 200, 10):
        _, thresh = cv2.threshold(gray, thresh_val, 255, cv2.THRESH_BINARY)
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kern)
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kern)

        contours, _ = cv2.findContours(
            thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            area = cv2.contourArea(cnt)
            rect_area = w * h
            if rect_area == 0:
                continue

            # Size: screen is a small-to-medium portion of the photo
            area_ratio = rect_area / img_area
            if area_ratio < 0.005 or area_ratio > 0.12:
                continue

            # Aspect ratio: LCD is roughly square (0.5 to 1.6)
            aspect = w / h
            if aspect < 0.5 or aspect > 1.6:
                continue

            # Rectangularity: how much of the bounding box the blob fills
            rectangularity = area / rect_area
            if rectangularity < 0.4:
                continue

            # KEY: edge density — LCD with text > 0.03, plain surfaces < 0.01
            roi_edges = edges[y : y + h, x : x + w]
            edge_density = np.sum(roi_edges > 0) / rect_area
            if edge_density < 0.03:
                continue

            # Score: edge density * area * rectangularity
            # This favours text-rich regions that are large and well-shaped
            score = edge_density * area * rectangularity
            if best is None or score > best[0]:
                best = (score, x, y, w, h)

    if best is None:
        return None

    return best[1:]
|
||||
|
||||
|
||||
def crop_screen(image_path, output_path, padding=15):
    """Load an image, find the screen, crop it and save the crop.

    Args:
        image_path: Path of the photo, read with cv2.imread.
        output_path: Destination path handed to cv2.imwrite (written with
            JPEG quality 95).
        padding: Extra pixels kept around the detected box on every side,
            clamped to the image bounds.

    Returns:
        True when the crop was written; False when the image could not be
        read, no screen was detected, or the crop could not be written.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f" ERROR: Could not read {image_path}")
        return False

    h_img, w_img = image.shape[:2]
    result = find_screen(image)

    if result is None:
        print(f" SKIP: No screen detected in {os.path.basename(image_path)}")
        return False

    x, y, w, h = result

    # Add padding, clamped to image bounds
    x1 = max(0, x - padding)
    y1 = max(0, y - padding)
    x2 = min(w_img, x + w + padding)
    y2 = min(h_img, y + h + padding)

    cropped = image[y1:y2, x1:x2]

    # cv2.imwrite signals failure through its boolean return value; without
    # this check a failed write would be reported and counted as a success.
    if not cv2.imwrite(output_path, cropped, [cv2.IMWRITE_JPEG_QUALITY, 95]):
        print(f" ERROR: Could not write {output_path}")
        return False

    print(
        f" OK: {os.path.basename(image_path)} -> {os.path.basename(output_path)} ({w}x{h})"
    )
    return True
|
||||
|
||||
|
||||
def main():
    """Crop the screen out of every JPEG found in the input directory.

    Command-line forms:
        two or more arguments: input_dir output_dir
        one argument:          input_dir (output goes to input_dir/cropped)
        no arguments:          the hard-coded /mnt/user-data defaults

    Each crop is written to the output directory as "<name>_screen.jpg",
    and a summary line is printed at the end.
    """
    if len(sys.argv) >= 3:
        input_dir = sys.argv[1]
        output_dir = sys.argv[2]
    elif len(sys.argv) == 2:
        input_dir = sys.argv[1]
        output_dir = os.path.join(input_dir, "cropped")
    else:
        input_dir = "/mnt/user-data/uploads"
        output_dir = "/mnt/user-data/outputs"

    os.makedirs(output_dir, exist_ok=True)

    # List the directory once and compare the suffix case-insensitively.
    # Globbing "*.JPEG", "*.jpeg", "*.jpg" and "*.JPG" separately returns
    # every file twice on case-insensitive filesystems and misses
    # mixed-case suffixes such as ".Jpg".
    jpeg_suffixes = {".jpg", ".jpeg"}
    images = sorted(
        path
        for path in glob.glob(os.path.join(input_dir, "*"))
        if os.path.splitext(path)[1].lower() in jpeg_suffixes
    )

    if not images:
        print(f"No images found in {input_dir}")
        return

    print(f"Found {len(images)} images in {input_dir}\n")

    success = 0
    for img_path in images:
        name = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(output_dir, f"{name}_screen.jpg")
        if crop_screen(img_path, out_path):
            success += 1

    print(f"\nDone: {success}/{len(images)} screens cropped -> {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,76 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "13389e33",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import extract_data as ed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "e5de5ac0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"imgs = ed.get_images(ed.PHOTOS_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "575fd8c9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "module 'extract_data' has no attribute 'plot_image'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m img = ed.convert_to_opencv_image(imgs[\u001b[32m0\u001b[39m])\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43med\u001b[49m\u001b[43m.\u001b[49m\u001b[43mplot_image\u001b[49m(img)\n",
|
||||
"\u001b[31mAttributeError\u001b[39m: module 'extract_data' has no attribute 'plot_image'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"img = ed.convert_to_opencv_image(imgs[0])\n",
|
||||
"ed.plot_image(img)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b8b7bebc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
168
extract_data.py
168
extract_data.py
@@ -1,168 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytesseract as tess
|
||||
from PIL import Image
|
||||
|
||||
PHOTOS_PATH = "./photos/"
|
||||
|
||||
# Get a list of images given a directory path
|
||||
def get_images(url: str):
    """Open every readable image file found directly inside a directory.

    Args:
        url: Path of the directory to scan (a filesystem path, despite the
            parameter name).

    Returns:
        A list of PIL images in sorted file-name order.  Entries that
        cannot be opened are reported on stdout and skipped.
    """
    images = []
    # os.listdir returns entries in arbitrary order; sorting makes the
    # result (and anything indexing into it) reproducible.
    for img_url in sorted(os.listdir(url)):
        try:
            image = Image.open(os.path.join(url, img_url))
        except OSError:  # IOError is an alias of OSError
            print(f"Error opening image: {img_url}")
        else:
            images.append(image)
    return images
|
||||
|
||||
|
||||
# Get the datetime taken from an image
|
||||
def get_datetime_taken(image: Image.Image) -> datetime | None:
    """Return the timestamp stored in the image's EXIF tag 306, if usable.

    Args:
        image: Object exposing getexif(), e.g. a PIL image.

    Returns:
        The parsed datetime, or None when tag 306 is absent or its value
        is not in the "%Y:%m:%d %H:%M:%S" format.
    """
    exif = image.getexif()
    if 306 not in exif:
        return None
    try:
        return datetime.strptime(exif[306], "%Y:%m:%d %H:%M:%S")
    except (TypeError, ValueError):
        # Malformed or non-string tag value: honour the "or None" contract
        # instead of raising out of a metadata lookup.
        return None
|
||||
|
||||
|
||||
# Convert an image to OpenCV format
|
||||
def convert_to_opencv_image(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a BGR array for use with OpenCV.

    Args:
        img: Source PIL image in any mode.

    Returns:
        The pixels as an array in BGR channel order.
    """
    # Normalise to 3-channel RGB first: the array of a grayscale or palette
    # image has no channel axis, which cv.COLOR_RGB2BGR cannot handle.
    rgb = np.array(img.convert("RGB"))
    return cv.cvtColor(rgb, cv.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def order_points(pts):
    """Order four 2-D points as top-left, top-right, bottom-right, bottom-left.

    Args:
        pts: Four (x, y) points as an array or nested sequence; any layout
            that reshapes to (4, 2) is accepted, e.g. (4, 1, 2).

    Returns:
        A (4, 2) float32 array holding the points in the order above.
    """
    # np.asarray lets plain lists/tuples through; arrays are used as-is.
    pts = np.asarray(pts).reshape(4, 2)
    rect = np.zeros((4, 2), dtype="float32")

    # x + y is smallest at the top-left corner and largest at bottom-right.
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]  # top-left
    rect[2] = pts[np.argmax(s)]  # bottom-right

    # y - x is smallest at the top-right corner and largest at bottom-left.
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]  # top-right
    rect[3] = pts[np.argmax(diff)]  # bottom-left

    return rect
|
||||
|
||||
|
||||
def is_closed_contour(cnt, eps=1.0):
    """Report whether a contour has area and ends where it starts.

    Args:
        cnt: Contour indexed as cnt[i][0] -> point, as used below.
        eps: Maximum distance between the first and last point for the
            contour to count as closed.

    Returns:
        False for a zero-area contour; otherwise whether the distance
        between the first and last point is below eps.
    """
    encloses_area = cv.contourArea(cnt) != 0
    if not encloses_area:
        return False

    first_point = cnt[0][0]
    last_point = cnt[-1][0]
    gap = cv.norm(first_point - last_point)
    return gap < eps
|
||||
|
||||
|
||||
# Optimise the image for OCR
|
||||
def process_image(img: Image.Image, *, preview: bool = True):
    """Locate the largest roughly-square region and warp it to 400x400.

    Edges are detected on a blurred copy of the image. The minimum-area
    rectangle of every external contour is kept when it covers at least
    1% of the image and its side ratio lies strictly between 0.9 and 1.1.
    The largest such rectangle is perspective-warped to a 400x400 image.

    Args:
        img: Source PIL image.
        preview: When True (the default, matching the previous behaviour)
            all detected contours are shown in a blocking preview window
            before warping.

    Returns:
        The warped 400x400 BGR image, or None when no candidate rectangle
        was found.
    """
    arr = convert_to_opencv_image(img)

    # Blur the image for better edge (contour) detection
    blur = cv.GaussianBlur(arr, (7, 7), 0)
    edges = cv.Canny(blur, 50, 100)
    contours, _ = cv.findContours(edges, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)

    # Filter contours for rectangles
    candidates = []
    img_area = arr.shape[0] * arr.shape[1]

    for cnt in contours:
        rect = cv.minAreaRect(cnt)
        _, (width, height), _ = rect
        box_contour = cv.boxPoints(rect).reshape((-1, 1, 2))
        if cv.contourArea(box_contour) < 0.01 * img_area:
            continue

        # Check the aspect ratio is reasonable
        aspect_ratio = width / float(height)
        if 0.9 < aspect_ratio < 1.1:
            candidates.append(box_contour)

    if not candidates:
        return None

    if preview:
        # Draw on a copy: drawing onto `arr` itself would burn the green
        # debug contours into the warped image returned below.
        debug_view = arr.copy()
        cv.drawContours(debug_view, contours, -1, (0, 255, 0), 3)
        preview_image(debug_view)

    # Most likely rectangle will be the largest one
    display_contour = max(candidates, key=cv.contourArea)
    rect = order_points(display_contour)
    (w, h) = (400, 400)
    dst = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype="float32")
    mat = cv.getPerspectiveTransform(rect, dst)
    warped = cv.warpPerspective(arr, mat, (w, h))
    return warped
|
||||
|
||||
|
||||
# Get the text from an image using OCR
|
||||
def ocr_image(img: Image.Image) -> str | None:
    """Placeholder for OCR text extraction.

    Not implemented: the argument is ignored and None is always returned,
    which is why the return annotation includes None.
    """
    return None
|
||||
|
||||
|
||||
# Process OCR text output
|
||||
def process_ocr_text(text: str) -> str | None:
    """Placeholder for post-processing of OCR output.

    Not implemented: the argument is ignored and None is always returned,
    which is why the return annotation includes None.
    """
    return None
|
||||
|
||||
|
||||
def get_gym(image: Image.Image) -> str | None:
    """Placeholder for working out which gym a photo was taken in.

    Intended heuristic, per the author's notes: there are two gyms, the
    Peckham gym and the Elephant and Castle gym, and they can be told
    apart by the colour of the wall. A green wall most likely means it's
    the Peckham gym; a blue wall most likely means it's the Elephant and
    Castle gym.

    Not implemented: the image is ignored and None is always returned.
    """
    return None
|
||||
|
||||
|
||||
def preview_image(img: np.ndarray):
    """Show an image in a window titled "preview" until a key is pressed.

    Blocks in cv.waitKey(0), then destroys all OpenCV windows.
    """
    window_title = "preview"
    cv.imshow(window_title, img)
    cv.waitKey(0)  # 0 = wait for a key press with no timeout
    cv.destroyAllWindows()
|
||||
|
||||
def plot_image(img, figsize=(6,6)):
    """Draw an image on a new matplotlib figure with the axes hidden.

    Args:
        img: Image data handed to Axes.imshow.
        figsize: Figure size forwarded to plt.subplots.

    Returns:
        The created figure.

    NOTE(review): arrays produced by convert_to_opencv_image are in BGR
    order, while imshow presumably treats 3-channel data as RGB — confirm
    whether callers should swap the channels first.
    """
    figure, axes = plt.subplots(figsize=figsize)
    axes.imshow(img)
    axes.axis("off")
    return figure
|
||||
|
||||
|
||||
def __main__():
    """Run process_image over every photo in PHOTOS_PATH and print a report.

    Prints "No images" and stops when the directory yields nothing.
    Otherwise prints running success/failure counts, then the file names
    that failed, the file names that succeeded, and the final counts.

    Returns:
        None.
    """
    imgs = get_images(PHOTOS_PATH)

    if not imgs:
        print("No images")
        return None

    fail = []
    success = []
    for img in imgs:
        if process_image(img) is None:
            fail.append(img.filename)
            continue

        success.append(img.filename)
        # NOTE(review): these running totals are assumed to sit inside the
        # loop (printed after each success) — confirm against the intended
        # layout, since the same two lines are repeated at the very end.
        print("success_len: ", len(success))
        print("fail_len: ", len(fail))

    print("failed:")
    for x in fail:
        print(x)

    print("success:")
    for x in success:
        print(x)

    print("success_len: ", len(success))
    print("fail_len: ", len(fail))

    return None


# No call to __main__() is visible in this module, so executing the file
# as a script would define everything and then do nothing.  The guard runs
# it for script execution only; importing the module is unaffected.
if __name__ == "__main__":
    __main__()
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 2.5 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.6 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 222 KiB |
381
screen_classifier.py
Normal file
381
screen_classifier.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Rowing Machine Display Classifier
|
||||
==================================
|
||||
Binary classifier: 1 = rowing machine display, 0 = not rowing machine.
|
||||
|
||||
Two modes:
|
||||
1. Feature-based (works immediately, no training needed)
|
||||
2. CNN-based, trained from scratch (needs training data)
|
||||
|
||||
Usage:
|
||||
# Predict with feature-based classifier (no training needed)
|
||||
python screen_classifier.py predict --image path/to/image.jpg
|
||||
|
||||
# Organize training data, then train CNN
|
||||
python screen_classifier.py train --data-dir data/
|
||||
python screen_classifier.py predict --image path/to/image.jpg --model cnn
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageStat, ImageFilter
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Feature-based classifier (works out of the box, no GPU needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_features(image_path: str) -> dict:
    """Extract hand-crafted features that distinguish rowing displays.

    The image is converted to RGB and to grayscale, both resized to
    256x256, and the following float features are computed:

        std_dev            standard deviation of the gray levels
        entropy            Shannon entropy (bits) of a 32-bin gray histogram
        edge_density       fraction of FIND_EDGES responses above 30
        h_line_strength    fraction of 3x3 gradient-kernel responses above 20
        mean_saturation    mean of (max - min) / max over the RGB channels
        dark_pixel_ratio   fraction of gray pixels below 80
        bright_pixel_ratio fraction of gray pixels above 180
        local_variance     mean squared difference to a Gaussian-blurred copy

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        Dict mapping the feature names above to float values.
    """
    # The context manager closes the underlying file; convert() returns an
    # independent in-memory image, so nothing is needed from it afterwards.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    gray = img.convert("L")

    # Resize for consistent analysis
    gray_resized = gray.resize((256, 256))
    img_resized = img.resize((256, 256))

    pixels = np.array(gray_resized, dtype=np.float64)
    color_pixels = np.array(img_resized, dtype=np.float64)

    features = {}

    # 1. Contrast: rowing displays have high contrast (dark text on light bg)
    features["std_dev"] = float(np.std(pixels))

    # 2. Bimodality: displays tend toward two clusters (text vs background)
    hist, _ = np.histogram(pixels, bins=32, range=(0, 256))
    hist_norm = hist / hist.sum()
    occupied = hist_norm[hist_norm > 0]  # skip empty bins: log2(0) is undefined
    features["entropy"] = float(-np.sum(occupied * np.log2(occupied)))

    # 3. Edge density: text/numbers create lots of edges
    edges = gray_resized.filter(ImageFilter.FIND_EDGES)
    edge_pixels = np.array(edges, dtype=np.float64)
    features["edge_density"] = float(np.mean(edge_pixels > 30))

    # 4. Horizontal line features: displays have horizontal separators
    sobel_h = gray_resized.filter(
        ImageFilter.Kernel(
            size=(3, 3), kernel=[-1, -2, -1, 0, 0, 0, 1, 2, 1], scale=1, offset=128
        )
    )
    # offset=128 shifts the filter output; subtracting it and taking the
    # absolute value recovers the response magnitude.
    sobel_pixels = np.abs(np.array(sobel_h, dtype=np.float64) - 128)
    features["h_line_strength"] = float(np.mean(sobel_pixels > 20))

    # 5. Color saturation: rowing displays are typically low-saturation
    r, g, b = color_pixels[:, :, 0], color_pixels[:, :, 1], color_pixels[:, :, 2]
    max_c = np.maximum(np.maximum(r, g), b)
    min_c = np.minimum(np.minimum(r, g), b)
    # Divide only where max_c > 0; pure-black pixels keep the 0 from `out`.
    # This gives the same values as masking afterwards, without dividing by
    # zero (and without the RuntimeWarning that came with it).
    saturation = np.divide(
        max_c - min_c, max_c, out=np.zeros_like(max_c), where=max_c > 0
    )
    features["mean_saturation"] = float(np.mean(saturation))

    # 6. Dark pixel ratio: displays have significant dark regions (text)
    features["dark_pixel_ratio"] = float(np.mean(pixels < 80))

    # 7. Bright pixel ratio: displays have bright background regions
    features["bright_pixel_ratio"] = float(np.mean(pixels > 180))

    # 8. Texture uniformity via local variance
    blurred = np.array(
        gray_resized.filter(ImageFilter.GaussianBlur(5)), dtype=np.float64
    )
    local_var = np.mean((pixels - blurred) ** 2)
    features["local_variance"] = float(local_var)

    return features
|
||||
|
||||
|
||||
def feature_based_predict(image_path: str, verbose: bool = False) -> tuple[int, float]:
    """
    Classify an image with a weighted checklist over hand-crafted features.

    Each satisfied rule adds its weight to a score that is capped at 1.0;
    a score of at least 0.45 yields label 1, anything lower label 0.

    Returns (label, confidence), where confidence is the score for label 1
    and 1.0 - score for label 0.
    """
    feats = extract_features(image_path)

    if verbose:
        print("\n Feature values:")
        for name, value in feats.items():
            print(f" {name:>20s}: {value:.4f}")

    # (rule satisfied?, weight) pairs.  The order is kept fixed so the
    # floating-point sum is accumulated in the same sequence every time.
    rules = (
        # Rowing displays: high contrast
        (feats["std_dev"] > 50, 0.15),
        (feats["std_dev"] > 70, 0.10),
        # High edge density (text/numbers)
        (feats["edge_density"] > 0.08, 0.15),
        (feats["edge_density"] > 0.15, 0.10),
        # Horizontal lines (separators between rows of data)
        (feats["h_line_strength"] > 0.06, 0.10),
        # Low saturation (monochrome-ish displays)
        (feats["mean_saturation"] < 0.15, 0.10),
        # Bimodal histogram (text vs background)
        (feats["entropy"] < 3.8, 0.10),
        # Has both dark and bright regions
        (
            feats["dark_pixel_ratio"] > 0.15 and feats["bright_pixel_ratio"] > 0.15,
            0.15,
        ),
        # High local variance = structured content
        (feats["local_variance"] > 200, 0.10),
    )

    score = 0.0
    for satisfied, weight in rules:
        if satisfied:
            score += weight
    score = min(score, 1.0)

    if score >= 0.45:
        return 1, score
    return 0, 1.0 - score
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CNN-based classifier (requires training)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_cnn_model():
    """Create a fresh, untrained RowingCNN for binary classification.

    The network is four 3x3 convolution stages (3 -> 32 -> 64 -> 128 -> 256
    channels), each followed by batch norm and ReLU.  The first three
    stages end in 2x2 max pooling and the last in adaptive average pooling
    to 4x4.  A two-layer fully connected head with dropout then produces a
    single output per image.

    Prints an install hint and exits with status 1 when PyTorch cannot be
    imported.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        print(
            "Error: PyTorch required for CNN mode. Install with: pip install torch torchvision"
        )
        sys.exit(1)

    class RowingCNN(nn.Module):
        def __init__(self):
            super().__init__()
            channels = (3, 32, 64, 128, 256)
            final_stage = len(channels) - 2

            # Layers are created in the same order as a hand-written list,
            # so the Sequential indices (and state_dict keys) stay the same.
            stages = []
            for idx, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
                stages.append(nn.Conv2d(c_in, c_out, 3, padding=1))
                stages.append(nn.BatchNorm2d(c_out))
                stages.append(nn.ReLU())
                if idx == final_stage:
                    stages.append(nn.AdaptiveAvgPool2d((4, 4)))
                else:
                    stages.append(nn.MaxPool2d(2))
            self.features = nn.Sequential(*stages)

            head = [
                nn.Flatten(),
                nn.Linear(256 * 4 * 4, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 1),
            ]
            self.classifier = nn.Sequential(*head)

        def forward(self, x):
            feature_maps = self.features(x)
            return self.classifier(feature_maps)

    return RowingCNN()
|
||||
|
||||
|
||||
def train_cnn(
    data_dir: str, epochs: int = 20, lr: float = 1e-3, save_path: str = "model.pth"
):
    """
    Train the CNN. Expects data_dir with structure:
        data_dir/
            train/
                0/  (non-rowing images)
                1/  (rowing images)
            val/    (optional, same structure)

    Args:
        data_dir: Root directory laid out as above.
        epochs: Number of passes over the training set.
        lr: Initial Adam learning rate; it is multiplied by 0.1 every
            7 epochs.
        save_path: File the model's state_dict is written to.

    With a val/ directory the weights are saved whenever validation
    accuracy improves; without one they are saved after every epoch.
    One progress line is printed per epoch.
    """
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    transform_train = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    transform_val = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    train_dir = os.path.join(data_dir, "train")
    train_ds = datasets.ImageFolder(train_dir, transform=transform_train)
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
    print(f"Training samples: {len(train_ds)} Classes: {train_ds.classes}")

    val_loader = None
    val_dir = os.path.join(data_dir, "val")
    if os.path.isdir(val_dir):
        val_ds = datasets.ImageFolder(val_dir, transform=transform_val)
        val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2)
        print(f"Validation samples: {len(val_ds)}")

    model = get_cnn_model().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Start below every possible accuracy so the first validated epoch is
    # always checkpointed.  Starting at 0.0 with a strict ">" meant a run
    # whose validation accuracy stayed at 0 never wrote save_path, yet the
    # final "Model saved" message was still printed.
    best_acc = -1.0
    for epoch in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch figure
            # is a true per-sample average.
            running_loss += loss.item() * inputs.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)

        scheduler.step()
        train_acc = correct / total
        avg_loss = running_loss / total
        line = f" Epoch {epoch + 1:>3d}/{epochs} loss={avg_loss:.4f} train_acc={train_acc:.3f}"

        # Explicit None check: the decision is "was a val/ directory
        # found", not whatever truthiness the loader object reports.
        if val_loader is not None:
            model.eval()
            val_correct, val_total = 0, 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.float().to(device)
                    outputs = model(inputs).squeeze(1)
                    preds = (torch.sigmoid(outputs) > 0.5).long()
                    val_correct += (preds == labels.long()).sum().item()
                    val_total += labels.size(0)
            val_acc = val_correct / val_total
            line += f" val_acc={val_acc:.3f}"
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), save_path)
                line += " *saved*"
        else:
            torch.save(model.state_dict(), save_path)

        print(line)

    print(f"\nModel saved to {save_path}")
|
||||
|
||||
|
||||
def cnn_predict(image_path: str, model_path: str = "model.pth") -> tuple[int, float]:
    """Predict using the trained CNN.

    Args:
        image_path: Path of the image to classify.
        model_path: Path of a state_dict file, as written by train_cnn.

    Returns:
        (label, confidence): label is 1 when the sigmoid output exceeds
        0.5, otherwise 0; confidence is the probability of that label.
    """
    import torch
    from torchvision import transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_cnn_model()
    # NOTE(security): torch.load unpickles the file, so only load weights
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # The context manager closes the image file; convert() returns an
    # independent in-memory copy that stays valid afterwards.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor).squeeze()
        prob = torch.sigmoid(output).item()

    label = 1 if prob > 0.5 else 0
    confidence = prob if label == 1 else 1.0 - prob
    return label, confidence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
    """Command-line entry point with three sub-commands.

    predict  classify one image with the feature-based scorer (default)
             or the CNN, and print the result
    train    train the CNN from a data directory
    extract  print the hand-crafted features of one image as JSON
    """
    parser = argparse.ArgumentParser(description="Rowing Machine Display Classifier")
    sub = parser.add_subparsers(dest="command", required=True)

    # --- predict ---
    predict_cmd = sub.add_parser("predict", help="Classify an image")
    predict_cmd.add_argument("--image", required=True, help="Path to image file")
    predict_cmd.add_argument(
        "--model",
        choices=["features", "cnn"],
        default="features",
        help="Which classifier to use (default: features)",
    )
    predict_cmd.add_argument(
        "--model-path", default="model.pth", help="Path to CNN weights"
    )
    predict_cmd.add_argument("--verbose", "-v", action="store_true")

    # --- train ---
    train_cmd = sub.add_parser("train", help="Train the CNN classifier")
    train_cmd.add_argument("--data-dir", required=True, help="Root data directory")
    train_cmd.add_argument("--epochs", type=int, default=20)
    train_cmd.add_argument("--lr", type=float, default=1e-3)
    train_cmd.add_argument("--save", default="model.pth", help="Where to save weights")

    # --- extract ---
    extract_cmd = sub.add_parser(
        "extract", help="Print extracted features for an image"
    )
    extract_cmd.add_argument("--image", required=True)

    args = parser.parse_args()
    command = args.command

    if command == "train":
        train_cnn(args.data_dir, epochs=args.epochs, lr=args.lr, save_path=args.save)
        return

    if command == "extract":
        print(json.dumps(extract_features(args.image), indent=2))
        return

    if command == "predict":
        if args.model == "cnn":
            label, conf = cnn_predict(args.image, args.model_path)
        else:
            label, conf = feature_based_predict(args.image, verbose=args.verbose)

        tag = "NOT ROWING MACHINE" if label != 1 else "ROWING MACHINE"
        print(f"\n Result: {tag} (label={label}, confidence={conf:.2f})\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user