392 lines
13 KiB
Python
392 lines
13 KiB
Python
"""
|
|
Extract rowing machine time and distance from cropped screen images using Tesseract OCR.
|
|
|
|
Uses multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote
|
|
extraction to reliably read the Concept 2 PM5 LCD display.
|
|
|
|
Usage:
|
|
python extract_screen_data.py --dir train/1/
|
|
python extract_screen_data.py --image path/to/cropped.jpg
|
|
python extract_screen_data.py --dir train/1/ --validate rowing_results.csv
|
|
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
def ocr_image(img, psm=6, whitelist="0123456789:."):
    """Run tesseract on a cv2 image and return the text.

    The image is written to a temporary PNG, handed to the ``tesseract``
    command-line tool, and the temporary file is removed again afterwards.

    Args:
        img: Image array accepted by ``cv2.imwrite``.
        psm: Value passed to tesseract's ``--psm`` option.
        whitelist: Characters passed as ``tessedit_char_whitelist``.

    Returns:
        Tesseract's standard output with surrounding whitespace stripped.
        The exit status is not checked, so a failed run yields whatever
        (possibly empty) text tesseract printed.

    Raises:
        OSError: If the ``tesseract`` executable cannot be started.
    """
    # Only reserve a unique path here and close the handle immediately:
    # writing to the file while it is still open fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        cv2.imwrite(tmp_path, img)
        result = subprocess.run(
            [
                "tesseract",
                tmp_path,
                "stdout",
                "--psm",
                str(psm),
                "-c",
                f"tessedit_char_whitelist={whitelist}",
            ],
            capture_output=True,
            text=True,
        )
    finally:
        # delete=False means cleanup is our job; do it even when the
        # subprocess call raises so failed runs do not litter the temp dir.
        os.unlink(tmp_path)

    return result.stdout.strip()
|
|
|
|
|
|
def preprocess_and_ocr(img_path):
    """OCR several preprocessed variants of one image.

    The image is converted to grayscale and upscaled by 2x and 3x. At 2x only,
    two CLAHE + fixed-threshold variants (120 and 130) are produced. At every
    scale, a blurred Otsu threshold and its inverse are produced.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.

    Returns:
        A list of ``(label, text)`` tuples, one per variant, with labels such
        as ``'s2_clahe_t120'``, ``'s3_otsu'`` or ``'s3_otsu_inv'``. The list
        is empty when the image cannot be read.
    """
    source = cv2.imread(img_path)
    if source is None:
        return []

    grayscale = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)

    # Collect (label, image) pairs first, then OCR them in the same order.
    variants = []
    for factor in (2, 3):
        upscaled = cv2.resize(
            grayscale, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
        )

        # CLAHE + fixed thresholds (only t120/t130 at scale 2 are useful)
        if factor == 2:
            equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            contrast_boosted = equalizer.apply(upscaled)
            for cutoff in (120, 130):
                binary = cv2.threshold(
                    contrast_boosted, cutoff, 255, cv2.THRESH_BINARY
                )[1]
                variants.append((f"s{factor}_clahe_t{cutoff}", binary))

        # Otsu on a lightly blurred image, plus its inverse
        smoothed = cv2.GaussianBlur(upscaled, (3, 3), 0)
        otsu_binary = cv2.threshold(
            smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )[1]
        variants.append((f"s{factor}_otsu", otsu_binary))
        variants.append((f"s{factor}_otsu_inv", 255 - otsu_binary))

    return [(label, ocr_image(variant)) for label, variant in variants]
|
|
|
|
|
|
def extract_times(labeled_texts):
    """Tally time candidates found in the OCR texts.

    Both H:MM:SS and MM:SS readings are recognised. H:MM:SS readings are
    converted to total minutes so every key has the form ``"<minutes>:SS"``.
    Only readings of 10-120 minutes with seconds in 0-59 are counted.

    Args:
        labeled_texts: Iterable of ``(label, text)`` pairs; labels are ignored.

    Returns:
        A ``Counter`` mapping each normalized time string to its vote count.
    """
    hms_pattern = re.compile(r"(\d{1,2}):(\d{2}):(\d{2})")
    ms_pattern = re.compile(r"(\d{2}):(\d{2})")

    tally = Counter()
    for _label, ocr_text in labeled_texts:
        # Spans covered by H:MM:SS readings; an MM:SS hit inside one of these
        # is a fragment of the longer reading and must not be counted again.
        hms_spans = []
        for hit in hms_pattern.finditer(ocr_text):
            hms_spans.append(hit.span())
            hours, minutes, seconds = (int(part) for part in hit.groups())
            minutes_total = hours * 60 + minutes
            if 10 <= minutes_total <= 120 and 0 <= seconds <= 59:
                tally[f"{minutes_total}:{seconds:02d}"] += 1

        for hit in ms_pattern.finditer(ocr_text):
            begin, end = hit.span()
            inside_hms = any(lo <= begin and end <= hi for lo, hi in hms_spans)
            if inside_hms:
                continue
            minutes, seconds = int(hit.group(1)), int(hit.group(2))
            # Keep only plausible rowing times: 10-120 min, valid seconds
            if 10 <= minutes <= 120 and 0 <= seconds <= 59:
                tally[f"{minutes}:{seconds:02d}"] += 1

    return tally
|
|
|
|
|
|
def extract_distances(labeled_texts):
    """Tally distance candidates found in the OCR texts.

    A candidate is a standalone run of 4-5 digits whose numeric value lies in
    1000-50000. The digit string exactly as matched is used as the key.

    Args:
        labeled_texts: Iterable of ``(label, text)`` pairs; labels are ignored.

    Returns:
        A ``Counter`` mapping each distance string to its vote count.
    """
    tally = Counter()
    for _label, ocr_text in labeled_texts:
        tally.update(
            digits
            for digits in re.findall(r"\b(\d{4,5})\b", ocr_text)
            if 1000 <= int(digits) <= 50000
        )
    return tally
|
|
|
|
|
|
def extract_times_by_method(labeled_texts):
    """Extract time candidates grouped by method label.

    Uses the same rules as the vote-counting extractor: H:MM:SS readings are
    converted to total minutes, MM:SS readings lying inside an H:MM:SS match
    are skipped, and only 10-120 minute readings with seconds 0-59 are kept.

    When the same label occurs more than once (for example when texts from
    several crops of one photo are concatenated), the candidates from all
    occurrences are merged rather than the last occurrence overwriting the
    earlier ones.

    Args:
        labeled_texts: Iterable of ``(label, text)`` pairs.

    Returns:
        Dict mapping each method label to the set of time strings found for
        it. A label with no valid reading maps to an empty set.
    """
    hms_pattern = re.compile(r"(\d{1,2}):(\d{2}):(\d{2})")
    ms_pattern = re.compile(r"(\d{2}):(\d{2})")

    by_method = {}
    for label, text in labeled_texts:
        # setdefault keeps candidates already gathered under this label.
        times = by_method.setdefault(label, set())

        # Spans of H:MM:SS readings, so their MM:SS tails are not re-counted.
        consumed = []
        for match in hms_pattern.finditer(text):
            consumed.append(match.span())
            hours, mins, secs = (int(part) for part in match.groups())
            total_mins = hours * 60 + mins
            if 10 <= total_mins <= 120 and 0 <= secs <= 59:
                times.add(f"{total_mins}:{secs:02d}")

        for match in ms_pattern.finditer(text):
            start, end = match.span()
            if any(cs <= start and end <= ce for cs, ce in consumed):
                continue
            mins, secs = int(match.group(1)), int(match.group(2))
            if 10 <= mins <= 120 and 0 <= secs <= 59:
                times.add(f"{mins}:{secs:02d}")

    return by_method
|
|
|
|
|
|
def extract_distances_by_method(labeled_texts):
    """Extract distance candidates grouped by method label.

    A candidate is a standalone run of 4-5 digits whose numeric value lies in
    1000-50000; the digit string exactly as matched is stored.

    When the same label occurs more than once (for example when texts from
    several crops of one photo are concatenated), the candidates from all
    occurrences are merged rather than the last occurrence overwriting the
    earlier ones.

    Args:
        labeled_texts: Iterable of ``(label, text)`` pairs.

    Returns:
        Dict mapping each method label to the set of distance strings found
        for it. A label with no valid reading maps to an empty set.
    """
    by_method = {}
    for label, text in labeled_texts:
        # setdefault keeps candidates already gathered under this label.
        dists = by_method.setdefault(label, set())
        dists.update(
            digits
            for digits in re.findall(r"\b(\d{4,5})\b", text)
            if 1000 <= int(digits) <= 50000
        )
    return by_method
|
|
|
|
|
|
def pick_best(counts, expected=None):
    """Choose the winning candidate from a vote tally.

    Args:
        counts: ``Counter`` of candidate -> votes.
        expected: Optional candidate that wins outright if it received any
            vote at all.

    Returns:
        ``None`` for an empty tally; otherwise ``expected`` when it is truthy
        and present in ``counts``, else the most common candidate.
    """
    if not counts:
        return None

    expected_was_seen = bool(expected) and expected in counts
    if expected_was_seen:
        return expected

    winner, _votes = counts.most_common(1)[0]
    return winner
|
|
|
|
|
|
def extract_screen_data(img_path, expected_time=None, expected_distance=None):
    """Read the time and distance shown on a cropped rowing screen image.

    Args:
        img_path: Path to the cropped screen image.
        expected_time: Optional expected time string (e.g. "30:06"). If any
            OCR variant produced it, it is returned as the time.
        expected_distance: Optional expected distance string (e.g. "7283").
            If any OCR variant produced it, it is returned as the distance.

    Returns:
        ``None`` when no OCR output was produced for the image; otherwise a
        dict with keys ``time`` (string or None), ``distance`` (int or None),
        ``time_votes``, ``distance_votes`` (vote Counters) and
        ``labeled_texts`` (the raw ``(label, text)`` OCR results).
    """
    ocr_outputs = preprocess_and_ocr(img_path)
    if not ocr_outputs:
        return None

    time_votes = extract_times(ocr_outputs)
    distance_votes = extract_distances(ocr_outputs)

    chosen_time = pick_best(time_votes, expected_time)
    chosen_distance = pick_best(distance_votes, expected_distance)
    # Distances are voted on as digit strings; callers get an int.
    distance_value = None if not chosen_distance else int(chosen_distance)

    return {
        "time": chosen_time,
        "distance": distance_value,
        "time_votes": time_votes,
        "distance_votes": distance_votes,
        "labeled_texts": ocr_outputs,
    }
|
|
|
|
|
|
def load_ground_truth(csv_path):
    """Load ground truth from rowing_results.csv.

    The file must have ``image``, ``time`` and ``distance_m`` columns.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        Dict mapping image base name (e.g. 'IMG_5413' for 'IMG_5413.JPEG')
        to ``{"time": <str>, "distance_m": <int>}``.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If ``distance_m`` is not an integer.
    """
    truth = {}
    # newline="" is what the csv module requires for correct quoted-newline
    # handling; utf-8-sig drops a leading BOM (common in spreadsheet exports)
    # that would otherwise turn the first header into "\ufeffimage".
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            # image column is like "IMG_5413.JPEG". Use Path.stem so names
            # are reduced the same way as the cropped file names are, even
            # when they contain extra dots or a directory part.
            base = Path(row["image"].strip()).stem
            truth[base] = {
                # Surrounding blanks would make the string comparison against
                # the OCR result fail for an otherwise correct reading.
                "time": row["time"].strip(),
                "distance_m": int(row["distance_m"]),
            }
    return truth
|
|
|
|
|
|
def image_base_name(filename):
    """Return the original photo name for a cropped-screen file name.

    The extension is dropped, then everything from the first ``_screen``
    onwards (including any "(N)" copy marker) is removed.

    e.g. 'IMG_5413_screen (1).jpg' -> 'IMG_5413'
    """
    screen_suffix = re.compile(r"_screen.*")
    stem = Path(filename).stem
    return screen_suffix.sub("", stem)
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Runs OCR on a single image (``--image``) or on every .jpg/.jpeg/.png in a
    directory (``--dir``). Crops are grouped by original photo name and their
    votes pooled. The best time and distance per photo are printed. With
    ``--validate``, results are compared against the ground-truth CSV and
    overall plus per-method accuracy is reported.

    Accuracy is computed only over photos that have a ground-truth row;
    photos without one are still processed and printed but are not scored.
    """
    from collections import defaultdict

    parser = argparse.ArgumentParser(
        description="Extract rowing data from cropped screen images using Tesseract OCR"
    )
    parser.add_argument("--image", help="Path to a single cropped image")
    parser.add_argument("--dir", help="Path to directory of cropped images")
    parser.add_argument(
        "--validate",
        help="Path to rowing_results.csv for validation",
    )
    args = parser.parse_args()

    if not args.image and not args.dir:
        parser.error("Provide --image or --dir")
    if args.image and args.dir:
        parser.error("--image and --dir are mutually exclusive")

    image_exts = {".jpg", ".jpeg", ".png"}

    if args.dir:
        dir_path = Path(args.dir)
        if not dir_path.is_dir():
            # Report a usage error instead of an iterdir() traceback.
            parser.error(f"--dir is not a directory: {args.dir}")
        image_paths = sorted(
            p for p in dir_path.iterdir() if p.suffix.lower() in image_exts
        )
        if not image_paths:
            print(f"No images found in {args.dir}")
            sys.exit(1)
    else:
        image_paths = [Path(args.image)]

    # Load ground truth if provided
    truth = load_ground_truth(args.validate) if args.validate else {}

    # Group images by base name (multiple crops per original photo) so that
    # each original photo yields a single pooled result.
    by_base = defaultdict(list)
    for p in image_paths:
        by_base[image_base_name(p.name)].append(p)

    correct_time = 0
    correct_dist = 0
    # Number of photos that actually have a ground-truth row. Only these can
    # be right or wrong, so only these belong in the accuracy denominator.
    validated = 0

    # Per-method accuracy tracking: method -> {time_correct, time_wrong, time_abstain, ...}
    method_stats = defaultdict(lambda: {
        "time_correct": 0, "time_wrong": 0, "time_abstain": 0,
        "dist_correct": 0, "dist_wrong": 0, "dist_abstain": 0,
    })

    print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")

    for base, paths in sorted(by_base.items()):
        gt = truth.get(base)
        expected_time = gt["time"] if gt else None
        expected_dist = str(gt["distance_m"]) if gt else None

        # Try each crop, collect all votes
        all_time_votes = Counter()
        all_dist_votes = Counter()
        all_labeled_texts = []

        for p in paths:
            # Expected values are deliberately not passed in, so the ground
            # truth cannot influence the result being validated.
            result = extract_screen_data(str(p))
            if result:
                all_time_votes.update(result["time_votes"])
                all_dist_votes.update(result["distance_votes"])
                all_labeled_texts.extend(result["labeled_texts"])

        best_time = pick_best(all_time_votes)
        best_dist = pick_best(all_dist_votes)

        time_ok = best_time == expected_time if gt else None
        dist_ok = best_dist == expected_dist if gt else None

        # Track overall and per-method accuracy against ground truth
        if gt:
            validated += 1
            if time_ok:
                correct_time += 1
            if dist_ok:
                correct_dist += 1

            times_by_method = extract_times_by_method(all_labeled_texts)
            dists_by_method = extract_distances_by_method(all_labeled_texts)

            for method, times in times_by_method.items():
                if not times:
                    method_stats[method]["time_abstain"] += 1
                elif expected_time in times:
                    method_stats[method]["time_correct"] += 1
                else:
                    method_stats[method]["time_wrong"] += 1

            for method, dists in dists_by_method.items():
                if not dists:
                    method_stats[method]["dist_abstain"] += 1
                elif expected_dist in dists:
                    method_stats[method]["dist_correct"] += 1
                else:
                    method_stats[method]["dist_wrong"] += 1

        # Display
        status_t = ""
        status_d = ""
        if gt:
            status_t = " OK" if time_ok else f" EXPECTED {expected_time}"
            status_d = " OK" if dist_ok else f" EXPECTED {expected_dist}"

        print(f" {base}:")
        print(
            f" Time: {best_time or '???'}{status_t} (votes: {all_time_votes.most_common(3)})"
        )
        print(
            f" Distance: {best_dist or '???'}{status_d} (votes: {all_dist_votes.most_common(3)})"
        )

    if truth:
        print("\nAccuracy:")
        if validated:
            print(f" Time: {correct_time}/{validated} ({correct_time / validated * 100:.0f}%)")
            print(f" Distance: {correct_dist}/{validated} ({correct_dist / validated * 100:.0f}%)")
        else:
            # Avoid dividing by zero when no processed photo is in the CSV.
            print(" No processed image has a ground-truth entry")

    if method_stats:
        print("\nMethod Accuracy:")
        print(
            f" {'Method':<20s} {'Time OK':>8s} {'Time Wrong':>11s} {'Time Abstain':>13s}"
            f" {'Dist OK':>8s} {'Dist Wrong':>11s} {'Dist Abstain':>13s}"
        )
        for method in sorted(method_stats):
            s = method_stats[method]
            print(
                f" {method:<20s} {s['time_correct']:>8d} {s['time_wrong']:>11d}"
                f" {s['time_abstain']:>13d} {s['dist_correct']:>8d}"
                f" {s['dist_wrong']:>11d} {s['dist_abstain']:>13d}"
            )
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|