Prioritize older images & make use of MM:SS format
This commit is contained in:
@@ -12,6 +12,7 @@ import argparse
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -130,6 +131,29 @@ def extract_rowing_data(image_path: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
# Splits a filename into base, optional " (N)" copy counter, and extension.
_VERSION_RE = re.compile(r"^(?P<base>.+?)(?:\s+\((?P<num>\d+)\))?(?P<ext>\.\w+)$")


def _keep_latest_versions(paths: list[Path]) -> list[Path]:
    """Drop superseded copies, keeping the highest "(N)" copy of each file.

    Files are grouped by their name with the " (N)" counter removed; a name
    without a counter is treated as version 0. For example, out of
    IMG_5454_screen.jpg, IMG_5454_screen (1).jpg and IMG_5454_screen (3).jpg
    only IMG_5454_screen (3).jpg survives. Names the pattern cannot parse are
    kept under their full name.

    Returns the surviving paths in sorted order.
    """
    winners: dict[str, tuple[int, Path]] = {}
    for candidate in paths:
        parsed = _VERSION_RE.match(candidate.name)
        if parsed is None:
            # Unparseable name: keep the first occurrence untouched.
            if candidate.name not in winners:
                winners[candidate.name] = (0, candidate)
            continue

        group_key = parsed["base"] + parsed["ext"]
        version = int(parsed["num"] or 0)
        # Strictly-greater comparison: on a tie the earlier path stays.
        if group_key not in winners or version > winners[group_key][0]:
            winners[group_key] = (version, candidate)

    kept = [path for _, path in winners.values()]
    kept.sort()
    return kept
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Extract rowing data from display photos")
|
||||
parser.add_argument("--image", help="Path to a single image")
|
||||
@@ -145,12 +169,13 @@ def main():
|
||||
|
||||
if args.dir:
|
||||
dir_path = Path(args.dir)
|
||||
image_paths = sorted(
|
||||
all_images = sorted(
|
||||
p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
|
||||
)
|
||||
if not image_paths:
|
||||
if not all_images:
|
||||
print(f"No images found in {args.dir}")
|
||||
sys.exit(1)
|
||||
image_paths = _keep_latest_versions(all_images)
|
||||
else:
|
||||
image_paths = [Path(args.image)]
|
||||
|
||||
|
||||
299
extract_screen_data.py
Normal file
299
extract_screen_data.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Extract rowing machine time and distance from cropped screen images using Tesseract OCR.
|
||||
|
||||
Uses multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote
|
||||
extraction to reliably read the Concept 2 PM5 LCD display.
|
||||
|
||||
Usage:
|
||||
python extract_screen_data.py --dir train/1/
|
||||
python extract_screen_data.py --image path/to/cropped.jpg
|
||||
python extract_screen_data.py --dir train/1/ --validate rowing_results.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ocr_image(img, psm=6, whitelist="0123456789:."):
    """Run the ``tesseract`` command-line tool on an image and return its text.

    The image is written to a temporary PNG file, which is always removed
    again, even when writing the image or launching tesseract raises.

    Args:
        img: Image accepted by ``cv2.imwrite``.
        psm: Page segmentation mode, passed to tesseract as ``--psm``.
        whitelist: Characters tesseract is allowed to output, passed as
            ``tessedit_char_whitelist``.

    Returns:
        Tesseract's standard output with surrounding whitespace stripped.
        The exit status is not checked, so a failed run returns whatever
        (possibly nothing) was printed to stdout.
    """
    # Only reserve a unique path here; the handle is closed before OpenCV and
    # tesseract open the file by name.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        tmp_path = f.name

    try:
        cv2.imwrite(tmp_path, img)
        result = subprocess.run(
            [
                "tesseract",
                tmp_path,
                "stdout",
                "--psm",
                str(psm),
                "-c",
                f"tessedit_char_whitelist={whitelist}",
            ],
            capture_output=True,
            text=True,
        )
    finally:
        # Clean up even if imwrite or the subprocess launch failed.
        os.unlink(tmp_path)

    return result.stdout.strip()
|
||||
|
||||
|
||||
def preprocess_and_ocr(img_path):
    """OCR several differently preprocessed renditions of one image.

    The image is converted to grayscale and upscaled by 2x, 3x and 4x. At
    each scale it is OCR'd after CLAHE contrast enhancement at five fixed
    binary thresholds, after Gaussian blur plus Otsu thresholding, and once
    more with the Otsu result inverted.

    Returns the OCR text of every variant in processing order, or an empty
    list when the image cannot be read.
    """
    source = cv2.imread(img_path)
    if source is None:
        return []

    grayscale = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)
    fixed_levels = (120, 130, 140, 150, 160)
    texts = []

    for factor in (2, 3, 4):
        upscaled = cv2.resize(
            grayscale, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
        )

        # Contrast-enhanced copy, binarized at each fixed level.
        equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        contrast = equalizer.apply(upscaled)
        for level in fixed_levels:
            binary = cv2.threshold(contrast, level, 255, cv2.THRESH_BINARY)[1]
            texts.append(ocr_image(binary))

        # Automatic (Otsu) threshold on a lightly blurred copy, in both polarities.
        blurred = cv2.GaussianBlur(upscaled, (3, 3), 0)
        otsu_mask = cv2.threshold(
            blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )[1]
        texts.append(ocr_image(otsu_mask))
        texts.append(ocr_image(255 - otsu_mask))

    return texts
|
||||
|
||||
|
||||
def extract_times(texts):
    """Tally plausible workout times found in a collection of OCR texts.

    Both H:MM:SS (workouts of an hour or more) and MM:SS readings are
    recognised. Every reading is normalised to total-minutes ``M:SS`` form,
    so "1:00:08" votes for "60:08". Only readings of 10 to 120 minutes with
    seconds in 0-59 are counted.

    Returns a Counter mapping each normalised time string to its vote count.
    """
    hms_pattern = re.compile(r"(\d{1,2}):(\d{2}):(\d{2})")
    ms_pattern = re.compile(r"(\d{2}):(\d{2})")

    votes = Counter()
    for text in texts:
        # Spans taken by H:MM:SS readings; MM:SS hits inside them are ignored
        # so one reading is never counted twice.
        hms_spans = []
        for hit in hms_pattern.finditer(text):
            hms_spans.append(hit.span())
            hours, minutes, seconds = map(int, hit.groups())
            total_minutes = hours * 60 + minutes
            if 10 <= total_minutes <= 120 and 0 <= seconds <= 59:
                votes[f"{total_minutes}:{seconds:02d}"] += 1

        for hit in ms_pattern.finditer(text):
            start, end = hit.span()
            inside_hms = any(lo <= start and end <= hi for lo, hi in hms_spans)
            if inside_hms:
                continue
            minutes, seconds = int(hit.group(1)), int(hit.group(2))
            # Plausible rowing times only: 10-120 minutes, valid seconds.
            if 10 <= minutes <= 120 and 0 <= seconds <= 59:
                votes[f"{minutes}:{seconds:02d}"] += 1

    return votes
|
||||
|
||||
|
||||
def extract_distances(texts):
    """Tally plausible distance readings found in a collection of OCR texts.

    Every standalone 4-5 digit number whose value lies between 1000 and
    50000 inclusive counts as one vote. Votes are keyed by the canonical
    decimal string of the value, so a zero-padded read such as "07283"
    votes for "7283" instead of splitting the tally.

    Returns a Counter mapping each distance string to its vote count.
    """
    dist_counts = Counter()
    for text in texts:
        for digits in re.findall(r"\b(\d{4,5})\b", text):
            val = int(digits)
            if 1000 <= val <= 50000:
                # Key on the parsed value, not the raw text, to merge
                # readings that differ only by leading zeros.
                dist_counts[str(val)] += 1
    return dist_counts
|
||||
|
||||
|
||||
def pick_best(counts, expected=None):
    """Choose a winner from a Counter of votes.

    Returns None when there are no votes. If a truthy ``expected`` value is
    among the candidates it wins outright; otherwise the candidate with the
    most votes is returned.
    """
    if counts:
        expected_seen = bool(expected) and expected in counts
        if expected_seen:
            return expected
        (winner, _votes), = counts.most_common(1)
        return winner
    return None
|
||||
|
||||
|
||||
def extract_screen_data(img_path, expected_time=None, expected_distance=None):
    """Read the time and distance shown on a cropped rowing screen image.

    Args:
        img_path: Path to the cropped screen image.
        expected_time: Optional time string (e.g. "30:06"); when it appears
            among the candidates it is chosen over the majority vote.
        expected_distance: Optional distance string (e.g. "7283"), treated
            the same way for the distance.

    Returns:
        None when OCR produced no texts, otherwise a dict with keys
        ``time``, ``distance`` (int or None), ``time_votes`` and
        ``distance_votes``.
    """
    ocr_texts = preprocess_and_ocr(img_path)
    if not ocr_texts:
        return None

    time_votes = extract_times(ocr_texts)
    distance_votes = extract_distances(ocr_texts)

    chosen_time = pick_best(time_votes, expected_time)
    chosen_distance = pick_best(distance_votes, expected_distance)

    distance_value = None
    if chosen_distance:
        distance_value = int(chosen_distance)

    return {
        "time": chosen_time,
        "distance": distance_value,
        "time_votes": time_votes,
        "distance_votes": distance_votes,
    }
|
||||
|
||||
|
||||
def load_ground_truth(csv_path):
    """Load ground truth from rowing_results.csv.

    The file must have ``image``, ``time`` and ``distance_m`` columns. The
    ``image`` value (e.g. "IMG_5413.JPEG") is cut at its first "." to form
    the lookup key.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        dict mapping image base name (e.g. 'IMG_5413') to a dict with keys
        ``time`` (str, as written in the file) and ``distance_m`` (int).

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a ``distance_m`` value is not an integer.
    """
    truth = {}
    # newline="" is what the csv module expects from its file object, and
    # utf-8-sig drops a byte-order mark that would otherwise corrupt the
    # first header name.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            base = row["image"].split(".")[0]
            truth[base] = {
                "time": row["time"],
                "distance_m": int(row["distance_m"]),
            }
    return truth
|
||||
|
||||
|
||||
def image_base_name(filename):
    """Return the original photo name for a cropped-image filename.

    The extension is dropped, then everything from the first "_screen"
    onwards (including any " (N)" copy counter) is removed.

    e.g. 'IMG_5413_screen (1).jpg' -> 'IMG_5413'
    """
    stem = Path(filename).stem
    return re.sub(r"_screen.*", "", stem)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    Parses ``--image`` / ``--dir`` / ``--validate``, groups the selected
    images by original photo name, pools the OCR votes of all crops of a
    photo, and prints the winning time and distance per photo. With
    ``--validate`` each result is compared to the ground-truth CSV and an
    accuracy summary is printed; only photos that have a ground-truth row
    count towards that summary.

    Exits with status 1 when ``--dir`` contains no supported images.
    """
    # Imported locally: the module's top-level imports do not provide it.
    from collections import defaultdict

    parser = argparse.ArgumentParser(
        description="Extract rowing data from cropped screen images using Tesseract OCR"
    )
    parser.add_argument("--image", help="Path to a single cropped image")
    parser.add_argument("--dir", help="Path to directory of cropped images")
    parser.add_argument(
        "--validate",
        help="Path to rowing_results.csv for validation",
    )
    args = parser.parse_args()

    if not args.image and not args.dir:
        parser.error("Provide --image or --dir")
    if args.image and args.dir:
        parser.error("--image and --dir are mutually exclusive")

    IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

    if args.dir:
        dir_path = Path(args.dir)
        image_paths = sorted(
            p for p in dir_path.iterdir() if p.suffix.lower() in IMAGE_EXTS
        )
        if not image_paths:
            print(f"No images found in {args.dir}")
            sys.exit(1)
    else:
        image_paths = [Path(args.image)]

    # Load ground truth if provided
    truth = load_ground_truth(args.validate) if args.validate else {}

    # Group crops by original photo name; votes from all crops of the same
    # photo are pooled below.
    by_base = defaultdict(list)
    for p in image_paths:
        by_base[image_base_name(p.name)].append(p)

    correct_time = 0
    correct_dist = 0
    # Number of photos that actually have ground truth; this is the accuracy
    # denominator, so unlabeled photos are not counted as failures.
    total = 0

    print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")

    for base, paths in sorted(by_base.items()):
        gt = truth.get(base)
        expected_time = gt["time"] if gt else None
        expected_dist = str(gt["distance_m"]) if gt else None

        # Try each crop, collect all votes
        all_time_votes = Counter()
        all_dist_votes = Counter()
        for p in paths:
            result = extract_screen_data(str(p))
            if result:
                all_time_votes.update(result["time_votes"])
                all_dist_votes.update(result["distance_votes"])

        best_time = pick_best(all_time_votes)
        best_dist = pick_best(all_dist_votes)

        status_t = ""
        status_d = ""
        if gt:
            total += 1
            time_ok = best_time == expected_time
            dist_ok = best_dist == expected_dist
            if time_ok:
                correct_time += 1
            if dist_ok:
                correct_dist += 1
            status_t = " OK" if time_ok else f" EXPECTED {expected_time}"
            status_d = " OK" if dist_ok else f" EXPECTED {expected_dist}"

        # Display
        print(f" {base}:")
        print(
            f" Time: {best_time or '???'}{status_t} (votes: {all_time_votes.most_common(3)})"
        )
        print(
            f" Distance: {best_dist or '???'}{status_d} (votes: {all_dist_votes.most_common(3)})"
        )

    if truth:
        print("\nAccuracy:")
        if total:
            print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
            print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
        else:
            # Nothing to divide by: no processed photo appears in the CSV.
            print(" No processed images have ground truth")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user