From c447cc41eae3e1b713bf6f69e1f7b36aa5964dc7 Mon Sep 17 00:00:00 2001 From: Adam French Date: Mon, 16 Mar 2026 15:20:59 +0000 Subject: [PATCH] Optimise parameters --- README.md | 92 ++++++++++++++++++++++++++++++ extract_screen_data.py | 124 +++++++++++++++++++++++++++++++++++------ 2 files changed, 200 insertions(+), 16 deletions(-) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..b08773a --- /dev/null +++ b/README.md @@ -0,0 +1,92 @@ +# Rowing Stats + +Extract workout data from photos of Concept 2 PM5 rowing machine displays using computer vision and Claude's vision API. + +## How It Works + +Photos go through a three-stage pipeline: + +``` +photos/ → crop_to_screen.py → screen_classifier.py → extract_screen_data.py → rowing_results.csv +``` + +1. **Screen Detection** (`crop_to_screen.py`) — Finds and perspective-corrects the LCD screen region using OpenCV edge detection, contour filtering, and morphological operations. Candidates are scored by `edge_density × area × rectangularity`. +2. **Classification** (`screen_classifier.py`) — Filters out non-rowing images. Supports a rule-based feature scorer (no training needed) and a 4-layer CNN with batch norm. +3. **Data Extraction** (`extract_screen_data.py`) — Extracts time and distance from cropped screen images using Tesseract OCR with multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote extraction. + +There is also `extract_rowing_data.py`, which uses Claude Haiku's vision API instead of Tesseract for data extraction. This serves as a reference/test for validating OCR accuracy but is more expensive to run due to API costs. + +There is also an Optuna-based hyperparameter tuner (`optimize_crop.py`) for the screen detection parameters. + +## Setup + +### Dependencies + +``` +pip install anthropic torch torchvision opencv-python Pillow numpy optuna +``` + +### API Key + +Create a `.env` file with your Anthropic API key: + +``` +ANTHROPIC_API_KEY=sk-ant-... +``` + +## Usage + +### Full pipeline + +```bash +# 1. Crop screens from photos +python crop_to_screen.py photos/ cropped/ + +# 2. Classify — keep only rowing displays +python screen_classifier.py predict --dir cropped/ + +# 3. Extract workout data via Tesseract OCR +python extract_screen_data.py --dir cropped/ + +# 3b. (Test) Extract via Claude API — more expensive, useful for validating OCR accuracy +python extract_rowing_data.py --dir photos/ +``` + +### Individual commands + +```bash +# Classify a single image (feature-based or CNN) +python screen_classifier.py predict --image path/to/img.jpg +python screen_classifier.py predict --image path/to/img.jpg --mode cnn + +# Extract data from a single image (Tesseract OCR) +python extract_screen_data.py --image path/to/img.jpg + +# Extract data from a single image (Claude API — for testing/validation) +python extract_rowing_data.py --image path/to/img.jpg + +# Train the CNN classifier +python screen_classifier.py train --data-dir train/ + +# Optimize crop detection parameters +python optimize_crop.py --n-trials 300 --photos-dir photos/ +``` + +## Training Data + +The CNN classifier trains on labeled images in `train/`: + +- `train/0/` — non-rowing images (negatives) +- `train/1/` — rowing display images (positives) + +The trained model is saved as `screen_classifier_model.pth`. + +## Validation + +Extracted metrics are validated against sensible bounds: + +| Metric | Min | Max | +| -------- | --------- | --------- | +| Distance | 100 m | 100,000 m | +| Time | 30 s | 2 hrs | +| Pace | 1:20/500m | 2:30/500m | diff --git a/extract_screen_data.py b/extract_screen_data.py index f95a1ce..ce45925 100644 --- a/extract_screen_data.py +++ b/extract_screen_data.py @@ -48,7 +48,8 @@ def ocr_image(img, psm=6, whitelist="0123456789:."): def preprocess_and_ocr(img_path): """Generate multiple preprocessed variants and OCR each one. - Returns a list of OCR text results from different preprocessing pipelines. + Returns a list of (label, text) tuples from different preprocessing pipelines. + Labels like 's2_clahe_t120', 's3_otsu', 's4_otsu_inv'. """ img = cv2.imread(img_path) if img is None: @@ -57,35 +58,37 @@ def preprocess_and_ocr(img_path): gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) results = [] - for scale in [2, 3, 4]: + for scale in [2, 3]: scaled = cv2.resize( gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC ) - # CLAHE + fixed thresholds - clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) - enhanced = clahe.apply(scaled) - for thresh_val in [120, 130, 140, 150, 160]: - _, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY) - results.append(ocr_image(thresh)) + # CLAHE + fixed thresholds (only t120/t130 at scale 2 are useful) + if scale == 2: + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + enhanced = clahe.apply(scaled) + for thresh_val in [120, 130]: + _, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY) + results.append((f"s{scale}_clahe_t{thresh_val}", ocr_image(thresh))) # Otsu blur = cv2.GaussianBlur(scaled, (3, 3), 0) _, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - results.append(ocr_image(otsu)) + results.append((f"s{scale}_otsu", ocr_image(otsu))) # Inverted Otsu - results.append(ocr_image(255 - otsu)) + results.append((f"s{scale}_otsu_inv", ocr_image(255 - otsu))) return results -def extract_times(texts): +def extract_times(labeled_texts): """Extract time candidates from OCR texts using majority voting. Supports both H:MM:SS (hour+ workouts) and MM:SS formats. All times are normalized to MM:SS format for downstream consistency. """ + texts = [t for _, t in labeled_texts] time_counts = Counter() for text in texts: # Match H:MM:SS patterns first (e.g. 1:00:08) @@ -117,8 +120,9 @@ def extract_times(texts): return time_counts -def extract_distances(texts): +def extract_distances(labeled_texts): """Extract distance candidates from OCR texts using majority voting.""" + texts = [t for _, t in labeled_texts] dist_counts = Counter() for text in texts: # Match 4-5 digit numbers (plausible rowing distances: 1000-99999m) @@ -130,6 +134,50 @@ def extract_distances(texts): return dist_counts +def extract_times_by_method(labeled_texts): + """Extract time candidates grouped by method label. + + Returns dict mapping method label to set of extracted time strings. + """ + by_method = {} + for label, text in labeled_texts: + times = set() + hms_matches = re.findall(r"(\d{1,2}):(\d{2}):(\d{2})", text) + consumed = set() + for h, m, s in hms_matches: + hours, mins, secs = int(h), int(m), int(s) + total_mins = hours * 60 + mins + if 10 <= total_mins <= 120 and 0 <= secs <= 59: + times.add(f"{total_mins}:{secs:02d}") + for match in re.finditer(r"\d{1,2}:\d{2}:\d{2}", text): + consumed.add((match.start(), match.end())) + for match in re.finditer(r"(\d{2}):(\d{2})", text): + if any(match.start() >= cs and match.end() <= ce for cs, ce in consumed): + continue + mins, secs = int(match.group(1)), int(match.group(2)) + if 10 <= mins <= 120 and 0 <= secs <= 59: + times.add(f"{mins}:{secs:02d}") + by_method[label] = times + return by_method + + +def extract_distances_by_method(labeled_texts): + """Extract distance candidates grouped by method label. + + Returns dict mapping method label to set of extracted distance strings. + """ + by_method = {} + for label, text in labeled_texts: + dists = set() + matches = re.findall(r"\b(\d{4,5})\b", text) + for m in matches: + val = int(m) + if 1000 <= val <= 50000: + dists.add(m) + by_method[label] = dists + return by_method + + def pick_best(counts, expected=None): """Pick the best candidate from vote counts. Prefer expected value if present.""" if not counts: @@ -150,12 +198,12 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None): Returns: dict with keys: time, distance, time_votes, distance_votes """ - texts = preprocess_and_ocr(img_path) - if not texts: + labeled_texts = preprocess_and_ocr(img_path) + if not labeled_texts: return None - time_counts = extract_times(texts) - dist_counts = extract_distances(texts) + time_counts = extract_times(labeled_texts) + dist_counts = extract_distances(labeled_texts) best_time = pick_best(time_counts, expected_time) best_distance = pick_best(dist_counts, expected_distance) @@ -165,6 +213,7 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None): "distance": int(best_distance) if best_distance else None, "time_votes": time_counts, "distance_votes": dist_counts, + "labeled_texts": labeled_texts, } @@ -245,6 +294,12 @@ def main(): correct_dist = 0 total = 0 + # Per-method accuracy tracking: method -> {time_correct, time_wrong, time_abstain, ...} + method_stats = defaultdict(lambda: { + "time_correct": 0, "time_wrong": 0, "time_abstain": 0, + "dist_correct": 0, "dist_wrong": 0, "dist_abstain": 0, + }) + print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n") for base, paths in sorted(by_base.items()): @@ -255,12 +310,14 @@ def main(): # Try each crop, collect all votes all_time_votes = Counter() all_dist_votes = Counter() + all_labeled_texts = [] for p in paths: result = extract_screen_data(str(p)) if result: all_time_votes.update(result["time_votes"]) all_dist_votes.update(result["distance_votes"]) + all_labeled_texts.extend(result["labeled_texts"]) best_time = pick_best(all_time_votes) best_dist = pick_best(all_dist_votes) @@ -274,6 +331,27 @@ def main(): if dist_ok: correct_dist += 1 + # Track per-method accuracy against ground truth + if gt: + times_by_method = extract_times_by_method(all_labeled_texts) + dists_by_method = extract_distances_by_method(all_labeled_texts) + + for method, times in times_by_method.items(): + if not times: + method_stats[method]["time_abstain"] += 1 + elif expected_time in times: + method_stats[method]["time_correct"] += 1 + else: + method_stats[method]["time_wrong"] += 1 + + for method, dists in dists_by_method.items(): + if not dists: + method_stats[method]["dist_abstain"] += 1 + elif expected_dist in dists: + method_stats[method]["dist_correct"] += 1 + else: + method_stats[method]["dist_wrong"] += 1 + # Display status_t = "" status_d = "" @@ -294,6 +372,20 @@ def main(): print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)") print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)") + if method_stats: + print(f"\nMethod Accuracy:") + print( + f" {'Method':<20s} {'Time OK':>8s} {'Time Wrong':>11s} {'Time Abstain':>13s}" + f" {'Dist OK':>8s} {'Dist Wrong':>11s} {'Dist Abstain':>13s}" + ) + for method in sorted(method_stats): + s = method_stats[method] + print( + f" {method:<20s} {s['time_correct']:>8d} {s['time_wrong']:>11d}" + f" {s['time_abstain']:>13d} {s['dist_correct']:>8d}" + f" {s['dist_wrong']:>11d} {s['dist_abstain']:>13d}" + ) + if __name__ == "__main__": main()