Prioritize older images & make use of MM:SS format

This commit is contained in:
2026-03-16 14:00:35 +00:00
parent 7383209804
commit ee8fa0960d
2 changed files with 326 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ import argparse
import base64
import json
import mimetypes
import re
import sys
from pathlib import Path
@@ -130,6 +131,29 @@ def extract_rowing_data(image_path: str) -> dict:
}
_VERSION_RE = re.compile(r"^(?P<base>.+?)(?:\s+\((?P<num>\d+)\))?(?P<ext>\.\w+)$")


def _keep_latest_versions(paths: list[Path]) -> list[Path]:
    """Return, for every base filename, only its highest-numbered copy.

    A copy number is the ``(N)`` that may sit between the name and the
    extension, e.g. ``IMG_5454_screen (3).jpg``. A file without such a suffix
    is treated as version 0. Given ``IMG_5454_screen.jpg``,
    ``IMG_5454_screen (1).jpg`` and ``IMG_5454_screen (3).jpg``, only the last
    one survives.

    Args:
        paths: Candidate image paths; only ``Path.name`` is inspected.

    Returns:
        The surviving paths in sorted order.
    """
    winners: dict[str, tuple[int, Path]] = {}
    for candidate in paths:
        parsed = _VERSION_RE.match(candidate.name)
        if parsed is None:
            # Names the pattern cannot split are kept under their own name;
            # the first occurrence wins.
            if candidate.name not in winners:
                winners[candidate.name] = (0, candidate)
            continue

        key = parsed["base"] + parsed["ext"]
        version = int(parsed["num"] or 0)
        incumbent = winners.get(key)
        # Strictly greater: on a tie the earlier path stays.
        if incumbent is None or version > incumbent[0]:
            winners[key] = (version, candidate)

    kept = [path for _version, path in winners.values()]
    kept.sort()
    return kept
def main():
parser = argparse.ArgumentParser(description="Extract rowing data from display photos")
parser.add_argument("--image", help="Path to a single image")
@@ -145,12 +169,13 @@ def main():
if args.dir:
dir_path = Path(args.dir)
image_paths = sorted(
all_images = sorted(
p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
)
if not image_paths:
if not all_images:
print(f"No images found in {args.dir}")
sys.exit(1)
image_paths = _keep_latest_versions(all_images)
else:
image_paths = [Path(args.image)]

299
extract_screen_data.py Normal file
View File

@@ -0,0 +1,299 @@
"""
Extract rowing machine time and distance from cropped screen images using Tesseract OCR.
Uses multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote
extraction to reliably read the Concept 2 PM5 LCD display.
Usage:
python extract_screen_data.py --dir train/1/
python extract_screen_data.py --image path/to/cropped.jpg
python extract_screen_data.py --dir train/1/ --validate rowing_results.csv
"""
import argparse
import csv
import os
import re
import subprocess
import sys
import tempfile
from collections import Counter
from pathlib import Path
import cv2
import numpy as np
def ocr_image(img, psm=6, whitelist="0123456789:."):
    """Run the ``tesseract`` command-line tool on a cv2 image.

    The image is written to a temporary PNG file, Tesseract is invoked on that
    file with its output sent to stdout, and the temporary file is removed
    again even if writing or the subprocess call raises.

    Args:
        img: Image to recognize, in a form accepted by ``cv2.imwrite``.
        psm: Value passed to Tesseract's ``--psm`` option.
        whitelist: Characters passed as ``tessedit_char_whitelist``.

    Returns:
        Tesseract's stdout with surrounding whitespace stripped.
    """
    # Only reserve a unique path here and close the handle straight away:
    # whether a still-open named temporary file can be opened a second time
    # by name is platform dependent.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        cv2.imwrite(tmp_path, img)
        result = subprocess.run(
            [
                "tesseract",
                tmp_path,
                "stdout",
                "--psm",
                str(psm),
                "-c",
                f"tessedit_char_whitelist={whitelist}",
            ],
            capture_output=True,
            text=True,
        )
    finally:
        # delete=False means cleanup is our job, on success and on failure.
        os.unlink(tmp_path)
    return result.stdout.strip()
def preprocess_and_ocr(img_path):
    """OCR several preprocessed variants of one image.

    The image is converted to grayscale and upscaled by 2x, 3x and 4x. At each
    scale seven variants are recognized, in this order: five fixed-threshold
    binarizations of the CLAHE-enhanced image, an Otsu binarization of the
    blurred image, and the inverse of that Otsu result.

    Args:
        img_path: Path of the image, as passed to ``cv2.imread``.

    Returns:
        List of OCR text results, one per variant; an empty list when
        ``cv2.imread`` cannot load the image.
    """
    image = cv2.imread(img_path)
    if image is None:
        return []

    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    fixed_levels = (120, 130, 140, 150, 160)
    texts = []
    for factor in (2, 3, 4):
        upscaled = cv2.resize(
            grayscale, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
        )

        # CLAHE + fixed thresholds
        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        contrast = equalizer.apply(upscaled)
        for level in fixed_levels:
            binary = cv2.threshold(contrast, level, 255, cv2.THRESH_BINARY)[1]
            texts.append(ocr_image(binary))

        # Otsu, followed by its inverse
        smoothed = cv2.GaussianBlur(upscaled, (3, 3), 0)
        otsu_mask = cv2.threshold(
            smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )[1]
        texts.append(ocr_image(otsu_mask))
        texts.append(ocr_image(255 - otsu_mask))
    return texts
def extract_times(texts):
    """Tally plausible time readings found in OCR texts.

    Both H:MM:SS (hour-plus workouts, e.g. ``1:00:08``) and MM:SS readings are
    recognized. Every accepted reading is normalized to ``<minutes>:<SS>``
    (so ``1:00:08`` becomes ``60:08``) and counted once per occurrence.
    Only readings of 10 to 120 total minutes with seconds 0-59 are counted.

    Args:
        texts: Iterable of OCR output strings.

    Returns:
        ``Counter`` mapping normalized time string to number of occurrences.
    """
    hms_pattern = re.compile(r"(\d{1,2}):(\d{2}):(\d{2})")
    ms_pattern = re.compile(r"(\d{2}):(\d{2})")
    time_counts = Counter()

    def tally(total_mins, secs):
        # Filter to plausible rowing times: 10-120 min, valid seconds.
        if 10 <= total_mins <= 120 and 0 <= secs <= 59:
            time_counts[f"{total_mins}:{secs:02d}"] += 1

    for text in texts:
        # One pass over the H:MM:SS matches: count the reading and remember
        # its span. The span is recorded even when the reading is out of
        # range, so its "MM:SS" tail is never re-read as a separate time.
        hms_spans = []
        for match in hms_pattern.finditer(text):
            hours, mins, secs = (int(part) for part in match.groups())
            hms_spans.append(match.span())
            tally(hours * 60 + mins, secs)

        for match in ms_pattern.finditer(text):
            start, end = match.span()
            # Skip anything lying entirely inside an H:MM:SS match.
            if any(lo <= start and end <= hi for lo, hi in hms_spans):
                continue
            tally(int(match.group(1)), int(match.group(2)))
    return time_counts
def extract_distances(texts):
    """Tally plausible distance readings found in OCR texts.

    A candidate is a standalone run of four or five digits whose integer
    value lies between 1000 and 50000 inclusive. Candidates are counted under
    their literal text as it appeared in the OCR output.

    Args:
        texts: Iterable of OCR output strings.

    Returns:
        ``Counter`` mapping digit string to number of occurrences.
    """
    tally = Counter()
    for ocr_text in texts:
        tally.update(
            token
            for token in re.findall(r"\b(\d{4,5})\b", ocr_text)
            if 1000 <= int(token) <= 50000
        )
    return tally
def pick_best(counts, expected=None):
    """Choose one candidate from a vote tally.

    Args:
        counts: ``Counter`` of candidate -> votes.
        expected: Optional preferred candidate. When it is truthy and present
            in ``counts`` it is returned regardless of its vote count.

    Returns:
        ``None`` for an empty tally, otherwise ``expected`` (under the rule
        above) or the candidate with the most votes.
    """
    if not counts:
        return None

    prefer_expected = bool(expected) and expected in counts
    if prefer_expected:
        return expected

    ((winner, _votes),) = counts.most_common(1)
    return winner
def extract_screen_data(img_path, expected_time=None, expected_distance=None):
    """Read time and distance from one cropped rowing screen image.

    Args:
        img_path: Path to the cropped screen image.
        expected_time: Optional expected time string (e.g. "30:06"). If it is
            among the OCR candidates it is chosen over the top-voted time.
        expected_distance: Optional expected distance string (e.g. "7283"),
            treated the same way for the distance.

    Returns:
        ``None`` when OCR produced no texts at all. Otherwise a dict with keys
        ``time`` (string or None), ``distance`` (int or None), ``time_votes``
        and ``distance_votes`` (the full vote tallies).
    """
    ocr_texts = preprocess_and_ocr(img_path)
    if not ocr_texts:
        return None

    time_votes = extract_times(ocr_texts)
    distance_votes = extract_distances(ocr_texts)
    chosen_time = pick_best(time_votes, expected_time)
    chosen_distance = pick_best(distance_votes, expected_distance)

    # Distances are tallied as digit strings; expose the winner as an int.
    if chosen_distance:
        distance_value = int(chosen_distance)
    else:
        distance_value = None

    return {
        "time": chosen_time,
        "distance": distance_value,
        "time_votes": time_votes,
        "distance_votes": distance_votes,
    }
def load_ground_truth(csv_path):
    """Load ground truth from a results CSV such as rowing_results.csv.

    The file must have a header row with the columns ``image``, ``time`` and
    ``distance_m``.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        Dict mapping image base name (e.g. 'IMG_5413') to
        ``{"time": <str>, "distance_m": <int>}``.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a ``distance_m`` value is not an integer.
    """
    truth = {}
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires. utf-8-sig drops a leading byte-order mark so
    # the first header is read as "image" rather than "\ufeffimage".
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            # image column is like "IMG_5413.JPEG"; keep the part before the
            # first dot.
            base = row["image"].split(".")[0]
            truth[base] = {
                "time": row["time"],
                "distance_m": int(row["distance_m"]),
            }
    return truth
def image_base_name(filename):
    """Map a cropped-image filename back to its original image name.

    The extension is dropped, then everything from the first ``_screen``
    onward (which includes any ``(N)`` copy indicator) is removed,
    e.g. 'IMG_5413_screen (1).jpg' -> 'IMG_5413'.

    Args:
        filename: File name or path of the cropped image.

    Returns:
        The base image name as a string.
    """
    stem = Path(filename).stem
    return re.sub(r"_screen.*", "", stem)
def main():
    """Command-line entry point: OCR cropped screen images and print results.

    Images come from ``--image`` (one file) or ``--dir`` (every .jpg/.jpeg/.png
    in the directory); exactly one of the two must be given. Crops that share
    a base image name are grouped, their vote tallies are merged, and the
    top-voted time and distance are printed per group.

    With ``--validate`` the results are compared against the ground-truth CSV
    and an accuracy summary is printed. Accuracy is computed only over images
    that have a ground-truth row; images without one are reported but not
    scored.

    Exits with status 1 when ``--dir`` contains no images.
    """
    # Local import: defaultdict is only needed here.
    from collections import defaultdict

    parser = argparse.ArgumentParser(
        description="Extract rowing data from cropped screen images using Tesseract OCR"
    )
    parser.add_argument("--image", help="Path to a single cropped image")
    parser.add_argument("--dir", help="Path to directory of cropped images")
    parser.add_argument(
        "--validate",
        help="Path to rowing_results.csv for validation",
    )
    args = parser.parse_args()
    if not args.image and not args.dir:
        parser.error("Provide --image or --dir")
    if args.image and args.dir:
        parser.error("--image and --dir are mutually exclusive")

    image_exts = {".jpg", ".jpeg", ".png"}
    if args.dir:
        dir_path = Path(args.dir)
        image_paths = sorted(
            p for p in dir_path.iterdir() if p.suffix.lower() in image_exts
        )
        if not image_paths:
            print(f"No images found in {args.dir}")
            sys.exit(1)
    else:
        image_paths = [Path(args.image)]

    # Load ground truth if provided
    truth = {}
    if args.validate:
        truth = load_ground_truth(args.validate)

    # Group images by base name (multiple crops per original photo); the
    # votes of all crops of one photo are pooled below.
    by_base = defaultdict(list)
    for p in image_paths:
        by_base[image_base_name(p.name)].append(p)

    correct_time = 0
    correct_dist = 0
    scored = 0  # images that have a ground-truth row
    print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
    for base, paths in sorted(by_base.items()):
        gt = truth.get(base)
        expected_time = gt["time"] if gt else None
        expected_dist = str(gt["distance_m"]) if gt else None

        # Try each crop, collect all votes. The expected values are not
        # passed to the extractor, so they cannot influence the result.
        all_time_votes = Counter()
        all_dist_votes = Counter()
        for p in paths:
            result = extract_screen_data(str(p))
            if result:
                all_time_votes.update(result["time_votes"])
                all_dist_votes.update(result["distance_votes"])
        best_time = pick_best(all_time_votes)
        best_dist = pick_best(all_dist_votes)

        time_ok = best_time == expected_time if gt else None
        dist_ok = best_dist == expected_dist if gt else None
        if gt:
            # Only images with ground truth enter the accuracy figures.
            scored += 1
            if time_ok:
                correct_time += 1
            if dist_ok:
                correct_dist += 1

        # Display
        status_t = ""
        status_d = ""
        if gt:
            status_t = " OK" if time_ok else f" EXPECTED {expected_time}"
            status_d = " OK" if dist_ok else f" EXPECTED {expected_dist}"
        print(f" {base}:")
        print(
            f" Time: {best_time or '???'}{status_t} (votes: {all_time_votes.most_common(3)})"
        )
        print(
            f" Distance: {best_dist or '???'}{status_d} (votes: {all_dist_votes.most_common(3)})"
        )

    if truth:
        print("\nAccuracy:")
        if scored:
            print(f" Time: {correct_time}/{scored} ({correct_time / scored * 100:.0f}%)")
            print(f" Distance: {correct_dist}/{scored} ({correct_dist / scored * 100:.0f}%)")
        else:
            # Avoid dividing by zero when no image matches the CSV.
            print(" No processed image has a ground-truth entry")


if __name__ == "__main__":
    main()