Optimise parameters

This commit is contained in:
2026-03-16 15:20:59 +00:00
parent ee8fa0960d
commit c447cc41ea
2 changed files with 200 additions and 16 deletions

92
README.md Normal file
View File

@@ -0,0 +1,92 @@
# Rowing Stats
Extract workout data from photos of Concept 2 PM5 rowing machine displays using computer vision and Claude's vision API.
## How It Works
Photos go through a three-stage pipeline:
```
photos/ → crop_to_screen.py → screen_classifier.py → extract_screen_data.py → rowing_results.csv
```
1. **Screen Detection** (`crop_to_screen.py`) — Finds and perspective-corrects the LCD screen region using OpenCV edge detection, contour filtering, and morphological operations. Candidates are scored by `edge_density × area × rectangularity`.
2. **Classification** (`screen_classifier.py`) — Filters out non-rowing images. Supports a rule-based feature scorer (no training needed) and a 4-layer CNN with batch norm.
3. **Data Extraction** (`extract_screen_data.py`) — Extracts time and distance from cropped screen images using Tesseract OCR with multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote extraction.
There is also `extract_rowing_data.py`, which uses Claude Haiku's vision API instead of Tesseract for data extraction. This serves as a reference/test for validating OCR accuracy but is more expensive to run due to API costs.
There is also an Optuna-based hyperparameter tuner (`optimize_crop.py`) for the screen detection parameters.
## Setup
### Dependencies
```
pip install anthropic torch torchvision opencv-python Pillow numpy optuna
```
### API Key
Create a `.env` file with your Anthropic API key:
```
ANTHROPIC_API_KEY=sk-ant-...
```
## Usage
### Full pipeline
```bash
# 1. Crop screens from photos
python crop_to_screen.py photos/ cropped/
# 2. Classify — keep only rowing displays
python screen_classifier.py predict --dir cropped/
# 3. Extract workout data via Tesseract OCR
python extract_screen_data.py --dir cropped/
# 3b. (Test) Extract via Claude API — more expensive, useful for validating OCR accuracy
python extract_rowing_data.py --dir photos/
```
### Individual commands
```bash
# Classify a single image (feature-based or CNN)
python screen_classifier.py predict --image path/to/img.jpg
python screen_classifier.py predict --image path/to/img.jpg --mode cnn
# Extract data from a single image (Tesseract OCR)
python extract_screen_data.py --image path/to/img.jpg
# Extract data from a single image (Claude API — for testing/validation)
python extract_rowing_data.py --image path/to/img.jpg
# Train the CNN classifier
python screen_classifier.py train --data-dir train/
# Optimize crop detection parameters
python optimize_crop.py --n-trials 300 --photos-dir photos/
```
## Training Data
The CNN classifier trains on labeled images in `train/`:
- `train/0/` — non-rowing images (negatives)
- `train/1/` — rowing display images (positives)
The trained model is saved as `screen_classifier_model.pth`.
## Validation
Extracted metrics are validated against sensible bounds:
| Metric | Min | Max |
| -------- | --------- | --------- |
| Distance | 100 m | 100,000 m |
| Time | 30 s | 2 hrs |
| Pace | 1:20/500m | 2:30/500m |

View File

@@ -48,7 +48,8 @@ def ocr_image(img, psm=6, whitelist="0123456789:."):
def preprocess_and_ocr(img_path):
"""Generate multiple preprocessed variants and OCR each one.
Returns a list of OCR text results from different preprocessing pipelines.
Returns a list of (label, text) tuples from different preprocessing pipelines.
Labels like 's2_clahe_t120', 's3_otsu', 's4_otsu_inv'.
"""
img = cv2.imread(img_path)
if img is None:
@@ -57,35 +58,37 @@ def preprocess_and_ocr(img_path):
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
results = []
for scale in [2, 3, 4]:
for scale in [2, 3]:
scaled = cv2.resize(
gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
)
# CLAHE + fixed thresholds
# CLAHE + fixed thresholds (only t120/t130 at scale 2 are useful)
if scale == 2:
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(scaled)
for thresh_val in [120, 130, 140, 150, 160]:
for thresh_val in [120, 130]:
_, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY)
results.append(ocr_image(thresh))
results.append((f"s{scale}_clahe_t{thresh_val}", ocr_image(thresh)))
# Otsu
blur = cv2.GaussianBlur(scaled, (3, 3), 0)
_, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
results.append(ocr_image(otsu))
results.append((f"s{scale}_otsu", ocr_image(otsu)))
# Inverted Otsu
results.append(ocr_image(255 - otsu))
results.append((f"s{scale}_otsu_inv", ocr_image(255 - otsu)))
return results
def extract_times(texts):
def extract_times(labeled_texts):
"""Extract time candidates from OCR texts using majority voting.
Supports both H:MM:SS (hour+ workouts) and MM:SS formats.
All times are normalized to MM:SS format for downstream consistency.
"""
texts = [t for _, t in labeled_texts]
time_counts = Counter()
for text in texts:
# Match H:MM:SS patterns first (e.g. 1:00:08)
@@ -117,8 +120,9 @@ def extract_times(texts):
return time_counts
def extract_distances(texts):
def extract_distances(labeled_texts):
"""Extract distance candidates from OCR texts using majority voting."""
texts = [t for _, t in labeled_texts]
dist_counts = Counter()
for text in texts:
# Match 4-5 digit numbers (plausible rowing distances: 1000-99999m)
@@ -130,6 +134,50 @@ def extract_distances(texts):
return dist_counts
def extract_times_by_method(labeled_texts):
"""Extract time candidates grouped by method label.
Returns dict mapping method label to set of extracted time strings.
"""
by_method = {}
for label, text in labeled_texts:
times = set()
hms_matches = re.findall(r"(\d{1,2}):(\d{2}):(\d{2})", text)
consumed = set()
for h, m, s in hms_matches:
hours, mins, secs = int(h), int(m), int(s)
total_mins = hours * 60 + mins
if 10 <= total_mins <= 120 and 0 <= secs <= 59:
times.add(f"{total_mins}:{secs:02d}")
for match in re.finditer(r"\d{1,2}:\d{2}:\d{2}", text):
consumed.add((match.start(), match.end()))
for match in re.finditer(r"(\d{2}):(\d{2})", text):
if any(match.start() >= cs and match.end() <= ce for cs, ce in consumed):
continue
mins, secs = int(match.group(1)), int(match.group(2))
if 10 <= mins <= 120 and 0 <= secs <= 59:
times.add(f"{mins}:{secs:02d}")
by_method[label] = times
return by_method
def extract_distances_by_method(labeled_texts):
"""Extract distance candidates grouped by method label.
Returns dict mapping method label to set of extracted distance strings.
"""
by_method = {}
for label, text in labeled_texts:
dists = set()
matches = re.findall(r"\b(\d{4,5})\b", text)
for m in matches:
val = int(m)
if 1000 <= val <= 50000:
dists.add(m)
by_method[label] = dists
return by_method
def pick_best(counts, expected=None):
"""Pick the best candidate from vote counts. Prefer expected value if present."""
if not counts:
@@ -150,12 +198,12 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
Returns:
dict with keys: time, distance, time_votes, distance_votes
"""
texts = preprocess_and_ocr(img_path)
if not texts:
labeled_texts = preprocess_and_ocr(img_path)
if not labeled_texts:
return None
time_counts = extract_times(texts)
dist_counts = extract_distances(texts)
time_counts = extract_times(labeled_texts)
dist_counts = extract_distances(labeled_texts)
best_time = pick_best(time_counts, expected_time)
best_distance = pick_best(dist_counts, expected_distance)
@@ -165,6 +213,7 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
"distance": int(best_distance) if best_distance else None,
"time_votes": time_counts,
"distance_votes": dist_counts,
"labeled_texts": labeled_texts,
}
@@ -245,6 +294,12 @@ def main():
correct_dist = 0
total = 0
# Per-method accuracy tracking: method -> {time_correct, time_wrong, time_abstain, ...}
method_stats = defaultdict(lambda: {
"time_correct": 0, "time_wrong": 0, "time_abstain": 0,
"dist_correct": 0, "dist_wrong": 0, "dist_abstain": 0,
})
print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
for base, paths in sorted(by_base.items()):
@@ -255,12 +310,14 @@ def main():
# Try each crop, collect all votes
all_time_votes = Counter()
all_dist_votes = Counter()
all_labeled_texts = []
for p in paths:
result = extract_screen_data(str(p))
if result:
all_time_votes.update(result["time_votes"])
all_dist_votes.update(result["distance_votes"])
all_labeled_texts.extend(result["labeled_texts"])
best_time = pick_best(all_time_votes)
best_dist = pick_best(all_dist_votes)
@@ -274,6 +331,27 @@ def main():
if dist_ok:
correct_dist += 1
# Track per-method accuracy against ground truth
if gt:
times_by_method = extract_times_by_method(all_labeled_texts)
dists_by_method = extract_distances_by_method(all_labeled_texts)
for method, times in times_by_method.items():
if not times:
method_stats[method]["time_abstain"] += 1
elif expected_time in times:
method_stats[method]["time_correct"] += 1
else:
method_stats[method]["time_wrong"] += 1
for method, dists in dists_by_method.items():
if not dists:
method_stats[method]["dist_abstain"] += 1
elif expected_dist in dists:
method_stats[method]["dist_correct"] += 1
else:
method_stats[method]["dist_wrong"] += 1
# Display
status_t = ""
status_d = ""
@@ -294,6 +372,20 @@ def main():
print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
if method_stats:
print(f"\nMethod Accuracy:")
print(
f" {'Method':<20s} {'Time OK':>8s} {'Time Wrong':>11s} {'Time Abstain':>13s}"
f" {'Dist OK':>8s} {'Dist Wrong':>11s} {'Dist Abstain':>13s}"
)
for method in sorted(method_stats):
s = method_stats[method]
print(
f" {method:<20s} {s['time_correct']:>8d} {s['time_wrong']:>11d}"
f" {s['time_abstain']:>13d} {s['dist_correct']:>8d}"
f" {s['dist_wrong']:>11d} {s['dist_abstain']:>13d}"
)
if __name__ == "__main__":
main()