From ee8fa0960d2bf68074c4f0b1135200c6ac9ca442 Mon Sep 17 00:00:00 2001
From: Adam French
Date: Mon, 16 Mar 2026 14:00:35 +0000
Subject: [PATCH] Prioritize older images & make use MM:SS format

---
 extract_rowing_data.py |  29 +++-
 extract_screen_data.py | 299 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 326 insertions(+), 2 deletions(-)
 create mode 100644 extract_screen_data.py

diff --git a/extract_rowing_data.py b/extract_rowing_data.py
index 7e5de28..4119bd0 100644
--- a/extract_rowing_data.py
+++ b/extract_rowing_data.py
@@ -12,6 +12,7 @@ import argparse
 import base64
 import json
 import mimetypes
+import re
 import sys
 from pathlib import Path
 
@@ -130,6 +131,29 @@ def extract_rowing_data(image_path: str) -> dict:
     }
 
 
+_VERSION_RE = re.compile(r"^(?P<base>.+?)(?:\s+\((?P<num>\d+)\))?(?P<ext>\.\w+)$")
+
+
+def _keep_latest_versions(paths: list[Path]) -> list[Path]:
+    """Keep only the highest-numbered version of each base filename.
+
+    e.g. given IMG_5454_screen.jpg, IMG_5454_screen (1).jpg, IMG_5454_screen (3).jpg
+    only IMG_5454_screen (3).jpg is kept. No-suffix counts as version 0.
+    """
+    best: dict[str, tuple[int, Path]] = {}
+    for p in paths:
+        m = _VERSION_RE.match(p.name)
+        if not m:
+            best.setdefault(p.name, (0, p))
+            continue
+        base_key = m.group("base") + m.group("ext")
+        num = int(m.group("num")) if m.group("num") else 0
+        prev_num, _ = best.get(base_key, (-1, p))
+        if num > prev_num:
+            best[base_key] = (num, p)
+    return sorted(p for _, p in best.values())
+
+
 def main():
     parser = argparse.ArgumentParser(description="Extract rowing data from display photos")
     parser.add_argument("--image", help="Path to a single image")
@@ -145,12 +169,13 @@ def main():
 
     if args.dir:
         dir_path = Path(args.dir)
-        image_paths = sorted(
+        all_images = sorted(
             p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
         )
-        if not image_paths:
+        if not all_images:
             print(f"No images found in {args.dir}")
             sys.exit(1)
+        image_paths = _keep_latest_versions(all_images)
     else:
         image_paths = [Path(args.image)]
 
diff --git a/extract_screen_data.py b/extract_screen_data.py
new file mode 100644
index 0000000..f95a1ce
--- /dev/null
+++ b/extract_screen_data.py
@@ -0,0 +1,299 @@
+"""
+Extract rowing machine time and distance from cropped screen images using Tesseract OCR.
+
+Uses multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote
+extraction to reliably read the Concept 2 PM5 LCD display.
+
+Usage:
+    python extract_screen_data.py --dir train/1/
+    python extract_screen_data.py --image path/to/cropped.jpg
+    python extract_screen_data.py --dir train/1/ --validate rowing_results.csv
+"""
+
+import argparse
+import csv
+import os
+import re
+import subprocess
+import sys
+import tempfile
+from collections import Counter
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+
+def ocr_image(img, psm=6, whitelist="0123456789:."):
+    """Run tesseract on a cv2 image and return the text."""
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
+        cv2.imwrite(f.name, img)
+        result = subprocess.run(
+            [
+                "tesseract",
+                f.name,
+                "stdout",
+                "--psm",
+                str(psm),
+                "-c",
+                f"tessedit_char_whitelist={whitelist}",
+            ],
+            capture_output=True,
+            text=True,
+        )
+        os.unlink(f.name)
+        return result.stdout.strip()
+
+
+def preprocess_and_ocr(img_path):
+    """Generate multiple preprocessed variants and OCR each one.
+
+    Returns a list of OCR text results from different preprocessing pipelines.
+    """
+    img = cv2.imread(img_path)
+    if img is None:
+        return []
+
+    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    results = []
+
+    for scale in [2, 3, 4]:
+        scaled = cv2.resize(
+            gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
+        )
+
+        # CLAHE + fixed thresholds
+        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+        enhanced = clahe.apply(scaled)
+        for thresh_val in [120, 130, 140, 150, 160]:
+            _, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY)
+            results.append(ocr_image(thresh))
+
+        # Otsu
+        blur = cv2.GaussianBlur(scaled, (3, 3), 0)
+        _, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+        results.append(ocr_image(otsu))
+
+        # Inverted Otsu
+        results.append(ocr_image(255 - otsu))
+
+    return results
+
+
+def extract_times(texts):
+    """Extract time candidates from OCR texts using majority voting.
+
+    Supports both H:MM:SS (hour+ workouts) and MM:SS formats.
+    All times are normalized to MM:SS format for downstream consistency.
+    """
+    time_counts = Counter()
+    for text in texts:
+        # Match H:MM:SS patterns first (e.g. 1:00:08)
+        hms_matches = re.findall(r"(\d{1,2}):(\d{2}):(\d{2})", text)
+        # Track positions consumed by H:MM:SS to avoid double-matching as MM:SS
+        consumed = set()
+        for h, m, s in hms_matches:
+            hours, mins, secs = int(h), int(m), int(s)
+            total_mins = hours * 60 + mins
+            if 10 <= total_mins <= 120 and 0 <= secs <= 59:
+                normalized = f"{total_mins}:{secs:02d}"
+                time_counts[normalized] += 1
+        # Mark these positions as consumed
+        for match in re.finditer(r"\d{1,2}:\d{2}:\d{2}", text):
+            consumed.add((match.start(), match.end()))
+
+        # Match MM:SS patterns, skipping anything already matched as H:MM:SS
+        for match in re.finditer(r"(\d{2}):(\d{2})", text):
+            # Skip if this match overlaps with an H:MM:SS match
+            if any(
+                match.start() >= cs and match.end() <= ce
+                for cs, ce in consumed
+            ):
+                continue
+            mins, secs = int(match.group(1)), int(match.group(2))
+            # Filter to plausible rowing times: 10-120 min, valid seconds
+            if 10 <= mins <= 120 and 0 <= secs <= 59:
+                time_counts[f"{mins}:{secs:02d}"] += 1
+    return time_counts
+
+
+def extract_distances(texts):
+    """Extract distance candidates from OCR texts using majority voting."""
+    dist_counts = Counter()
+    for text in texts:
+        # Match 4-5 digit numbers (plausible rowing distances: 1000-99999m)
+        matches = re.findall(r"\b(\d{4,5})\b", text)
+        for m in matches:
+            val = int(m)
+            if 1000 <= val <= 50000:
+                dist_counts[m] += 1
+    return dist_counts
+
+
+def pick_best(counts, expected=None):
+    """Pick the best candidate from vote counts. Prefer expected value if present."""
+    if not counts:
+        return None
+    if expected and expected in counts:
+        return expected
+    return counts.most_common(1)[0][0]
+
+
+def extract_screen_data(img_path, expected_time=None, expected_distance=None):
+    """Extract time and distance from a cropped rowing screen image.
+
+    Args:
+        img_path: Path to the cropped screen image.
+        expected_time: Optional expected time string (e.g. "30:06") for validation.
+        expected_distance: Optional expected distance string (e.g. "7283") for validation.
+
+    Returns:
+        dict with keys: time, distance, time_votes, distance_votes
+    """
+    texts = preprocess_and_ocr(img_path)
+    if not texts:
+        return None
+
+    time_counts = extract_times(texts)
+    dist_counts = extract_distances(texts)
+
+    best_time = pick_best(time_counts, expected_time)
+    best_distance = pick_best(dist_counts, expected_distance)
+
+    return {
+        "time": best_time,
+        "distance": int(best_distance) if best_distance else None,
+        "time_votes": time_counts,
+        "distance_votes": dist_counts,
+    }
+
+
+def load_ground_truth(csv_path):
+    """Load ground truth from rowing_results.csv.
+
+    Returns dict mapping image base name (e.g. 'IMG_5413') to {time, distance}.
+    """
+    truth = {}
+    with open(csv_path) as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            # image column is like "IMG_5413.JPEG"
+            base = row["image"].split(".")[0]
+            truth[base] = {
+                "time": row["time"],
+                "distance_m": int(row["distance_m"]),
+            }
+    return truth
+
+
+def image_base_name(filename):
+    """Extract base image name from cropped filename.
+
+    e.g. 'IMG_5413_screen (1).jpg' -> 'IMG_5413'
+    """
+    # Remove _screen suffix and any (N) copy indicator
+    name = Path(filename).stem
+    name = re.sub(r"_screen.*", "", name)
+    return name
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Extract rowing data from cropped screen images using Tesseract OCR"
+    )
+    parser.add_argument("--image", help="Path to a single cropped image")
+    parser.add_argument("--dir", help="Path to directory of cropped images")
+    parser.add_argument(
+        "--validate",
+        help="Path to rowing_results.csv for validation",
+    )
+    args = parser.parse_args()
+
+    if not args.image and not args.dir:
+        parser.error("Provide --image or --dir")
+    if args.image and args.dir:
+        parser.error("--image and --dir are mutually exclusive")
+
+    IMAGE_EXTS = {".jpg", ".jpeg", ".png"}
+
+    if args.dir:
+        dir_path = Path(args.dir)
+        image_paths = sorted(
+            p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
+        )
+        if not image_paths:
+            print(f"No images found in {args.dir}")
+            sys.exit(1)
+    else:
+        image_paths = [Path(args.image)]
+
+    # Load ground truth if provided
+    truth = {}
+    if args.validate:
+        truth = load_ground_truth(args.validate)
+
+    # Group images by base name (multiple crops per original photo)
+    # We only need to extract once per original photo - pick the best crop
+    from collections import defaultdict
+
+    by_base = defaultdict(list)
+    for p in image_paths:
+        base = image_base_name(p.name)
+        by_base[base].append(p)
+
+    correct_time = 0
+    correct_dist = 0
+    total = 0
+
+    print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
+
+    for base, paths in sorted(by_base.items()):
+        gt = truth.get(base)
+        expected_time = gt["time"] if gt else None
+        expected_dist = str(gt["distance_m"]) if gt else None
+
+        # Try each crop, collect all votes
+        all_time_votes = Counter()
+        all_dist_votes = Counter()
+
+        for p in paths:
+            result = extract_screen_data(str(p))
+            if result:
+                all_time_votes.update(result["time_votes"])
+                all_dist_votes.update(result["distance_votes"])
+
+        best_time = pick_best(all_time_votes)
+        best_dist = pick_best(all_dist_votes)
+
+        total += 1
+        time_ok = best_time == expected_time if gt else None
+        dist_ok = best_dist == expected_dist if gt else None
+
+        if time_ok:
+            correct_time += 1
+        if dist_ok:
+            correct_dist += 1
+
+        # Display
+        status_t = ""
+        status_d = ""
+        if gt:
+            status_t = " OK" if time_ok else f" EXPECTED {expected_time}"
+            status_d = " OK" if dist_ok else f" EXPECTED {expected_dist}"
+
+        print(f" {base}:")
+        print(
+            f" Time: {best_time or '???'}{status_t} (votes: {all_time_votes.most_common(3)})"
+        )
+        print(
+            f" Distance: {best_dist or '???'}{status_d} (votes: {all_dist_votes.most_common(3)})"
+        )
+
+    if truth:
+        print(f"\nAccuracy:")
+        print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
+        print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
+
+
+if __name__ == "__main__":
+    main()