Optimise parameters
This commit is contained in:
92
README.md
Normal file
92
README.md
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# Rowing Stats
|
||||||
|
|
||||||
|
Extract workout data from photos of Concept 2 PM5 rowing machine displays using computer vision and Claude's vision API.
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
Photos go through a three-stage pipeline:
|
||||||
|
|
||||||
|
```
|
||||||
|
photos/ → crop_to_screen.py → screen_classifier.py → extract_screen_data.py → rowing_results.csv
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **Screen Detection** (`crop_to_screen.py`) — Finds and perspective-corrects the LCD screen region using OpenCV edge detection, contour filtering, and morphological operations. Candidates are scored by `edge_density × area × rectangularity`.
|
||||||
|
2. **Classification** (`screen_classifier.py`) — Filters out non-rowing images. Supports a rule-based feature scorer (no training needed) and a 4-layer CNN with batch norm.
|
||||||
|
3. **Data Extraction** (`extract_screen_data.py`) — Extracts time and distance from cropped screen images using Tesseract OCR with multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote extraction.
|
||||||
|
|
||||||
|
There is also `extract_rowing_data.py`, which uses Claude Haiku's vision API instead of Tesseract for data extraction. This serves as a reference/test for validating OCR accuracy but is more expensive to run due to API costs.
|
||||||
|
|
||||||
|
There is also an Optuna-based hyperparameter tuner (`optimize_crop.py`) for the screen detection parameters.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install anthropic torch torchvision opencv-python Pillow numpy optuna
|
||||||
|
```
|
||||||
|
|
||||||
|
### API Key
|
||||||
|
|
||||||
|
Create a `.env` file with your Anthropic API key:
|
||||||
|
|
||||||
|
```
|
||||||
|
ANTHROPIC_API_KEY=sk-ant-...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Full pipeline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Crop screens from photos
|
||||||
|
python crop_to_screen.py photos/ cropped/
|
||||||
|
|
||||||
|
# 2. Classify — keep only rowing displays
|
||||||
|
python screen_classifier.py predict --dir cropped/
|
||||||
|
|
||||||
|
# 3. Extract workout data via Tesseract OCR
|
||||||
|
python extract_screen_data.py --dir cropped/
|
||||||
|
|
||||||
|
# 3b. (Test) Extract via Claude API — more expensive, useful for validating OCR accuracy
|
||||||
|
python extract_rowing_data.py --dir photos/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Individual commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Classify a single image (feature-based or CNN)
|
||||||
|
python screen_classifier.py predict --image path/to/img.jpg
|
||||||
|
python screen_classifier.py predict --image path/to/img.jpg --mode cnn
|
||||||
|
|
||||||
|
# Extract data from a single image (Tesseract OCR)
|
||||||
|
python extract_screen_data.py --image path/to/img.jpg
|
||||||
|
|
||||||
|
# Extract data from a single image (Claude API — for testing/validation)
|
||||||
|
python extract_rowing_data.py --image path/to/img.jpg
|
||||||
|
|
||||||
|
# Train the CNN classifier
|
||||||
|
python screen_classifier.py train --data-dir train/
|
||||||
|
|
||||||
|
# Optimize crop detection parameters
|
||||||
|
python optimize_crop.py --n-trials 300 --photos-dir photos/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training Data
|
||||||
|
|
||||||
|
The CNN classifier trains on labeled images in `train/`:
|
||||||
|
|
||||||
|
- `train/0/` — non-rowing images (negatives)
|
||||||
|
- `train/1/` — rowing display images (positives)
|
||||||
|
|
||||||
|
The trained model is saved as `screen_classifier_model.pth`.
|
||||||
|
|
||||||
|
## Validation
|
||||||
|
|
||||||
|
Extracted metrics are validated against sensible bounds:
|
||||||
|
|
||||||
|
| Metric | Min | Max |
|
||||||
|
| -------- | --------- | --------- |
|
||||||
|
| Distance | 100 m | 100,000 m |
|
||||||
|
| Time | 30 s | 2 hrs |
|
||||||
|
| Pace | 1:20/500m | 2:30/500m |
|
||||||
@@ -48,7 +48,8 @@ def ocr_image(img, psm=6, whitelist="0123456789:."):
|
|||||||
def preprocess_and_ocr(img_path):
|
def preprocess_and_ocr(img_path):
|
||||||
"""Generate multiple preprocessed variants and OCR each one.
|
"""Generate multiple preprocessed variants and OCR each one.
|
||||||
|
|
||||||
Returns a list of OCR text results from different preprocessing pipelines.
|
Returns a list of (label, text) tuples from different preprocessing pipelines.
|
||||||
|
Labels like 's2_clahe_t120', 's3_otsu', 's4_otsu_inv'.
|
||||||
"""
|
"""
|
||||||
img = cv2.imread(img_path)
|
img = cv2.imread(img_path)
|
||||||
if img is None:
|
if img is None:
|
||||||
@@ -57,35 +58,37 @@ def preprocess_and_ocr(img_path):
|
|||||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for scale in [2, 3, 4]:
|
for scale in [2, 3]:
|
||||||
scaled = cv2.resize(
|
scaled = cv2.resize(
|
||||||
gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
|
gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
|
||||||
)
|
)
|
||||||
|
|
||||||
# CLAHE + fixed thresholds
|
# CLAHE + fixed thresholds (only t120/t130 at scale 2 are useful)
|
||||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
if scale == 2:
|
||||||
enhanced = clahe.apply(scaled)
|
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||||
for thresh_val in [120, 130, 140, 150, 160]:
|
enhanced = clahe.apply(scaled)
|
||||||
_, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY)
|
for thresh_val in [120, 130]:
|
||||||
results.append(ocr_image(thresh))
|
_, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY)
|
||||||
|
results.append((f"s{scale}_clahe_t{thresh_val}", ocr_image(thresh)))
|
||||||
|
|
||||||
# Otsu
|
# Otsu
|
||||||
blur = cv2.GaussianBlur(scaled, (3, 3), 0)
|
blur = cv2.GaussianBlur(scaled, (3, 3), 0)
|
||||||
_, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
_, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||||
results.append(ocr_image(otsu))
|
results.append((f"s{scale}_otsu", ocr_image(otsu)))
|
||||||
|
|
||||||
# Inverted Otsu
|
# Inverted Otsu
|
||||||
results.append(ocr_image(255 - otsu))
|
results.append((f"s{scale}_otsu_inv", ocr_image(255 - otsu)))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def extract_times(texts):
|
def extract_times(labeled_texts):
|
||||||
"""Extract time candidates from OCR texts using majority voting.
|
"""Extract time candidates from OCR texts using majority voting.
|
||||||
|
|
||||||
Supports both H:MM:SS (hour+ workouts) and MM:SS formats.
|
Supports both H:MM:SS (hour+ workouts) and MM:SS formats.
|
||||||
All times are normalized to MM:SS format for downstream consistency.
|
All times are normalized to MM:SS format for downstream consistency.
|
||||||
"""
|
"""
|
||||||
|
texts = [t for _, t in labeled_texts]
|
||||||
time_counts = Counter()
|
time_counts = Counter()
|
||||||
for text in texts:
|
for text in texts:
|
||||||
# Match H:MM:SS patterns first (e.g. 1:00:08)
|
# Match H:MM:SS patterns first (e.g. 1:00:08)
|
||||||
@@ -117,8 +120,9 @@ def extract_times(texts):
|
|||||||
return time_counts
|
return time_counts
|
||||||
|
|
||||||
|
|
||||||
def extract_distances(texts):
|
def extract_distances(labeled_texts):
|
||||||
"""Extract distance candidates from OCR texts using majority voting."""
|
"""Extract distance candidates from OCR texts using majority voting."""
|
||||||
|
texts = [t for _, t in labeled_texts]
|
||||||
dist_counts = Counter()
|
dist_counts = Counter()
|
||||||
for text in texts:
|
for text in texts:
|
||||||
# Match 4-5 digit numbers (plausible rowing distances: 1000-99999m)
|
# Match 4-5 digit numbers (plausible rowing distances: 1000-99999m)
|
||||||
@@ -130,6 +134,50 @@ def extract_distances(texts):
|
|||||||
return dist_counts
|
return dist_counts
|
||||||
|
|
||||||
|
|
||||||
|
def extract_times_by_method(labeled_texts):
|
||||||
|
"""Extract time candidates grouped by method label.
|
||||||
|
|
||||||
|
Returns dict mapping method label to set of extracted time strings.
|
||||||
|
"""
|
||||||
|
by_method = {}
|
||||||
|
for label, text in labeled_texts:
|
||||||
|
times = set()
|
||||||
|
hms_matches = re.findall(r"(\d{1,2}):(\d{2}):(\d{2})", text)
|
||||||
|
consumed = set()
|
||||||
|
for h, m, s in hms_matches:
|
||||||
|
hours, mins, secs = int(h), int(m), int(s)
|
||||||
|
total_mins = hours * 60 + mins
|
||||||
|
if 10 <= total_mins <= 120 and 0 <= secs <= 59:
|
||||||
|
times.add(f"{total_mins}:{secs:02d}")
|
||||||
|
for match in re.finditer(r"\d{1,2}:\d{2}:\d{2}", text):
|
||||||
|
consumed.add((match.start(), match.end()))
|
||||||
|
for match in re.finditer(r"(\d{2}):(\d{2})", text):
|
||||||
|
if any(match.start() >= cs and match.end() <= ce for cs, ce in consumed):
|
||||||
|
continue
|
||||||
|
mins, secs = int(match.group(1)), int(match.group(2))
|
||||||
|
if 10 <= mins <= 120 and 0 <= secs <= 59:
|
||||||
|
times.add(f"{mins}:{secs:02d}")
|
||||||
|
by_method[label] = times
|
||||||
|
return by_method
|
||||||
|
|
||||||
|
|
||||||
|
def extract_distances_by_method(labeled_texts):
|
||||||
|
"""Extract distance candidates grouped by method label.
|
||||||
|
|
||||||
|
Returns dict mapping method label to set of extracted distance strings.
|
||||||
|
"""
|
||||||
|
by_method = {}
|
||||||
|
for label, text in labeled_texts:
|
||||||
|
dists = set()
|
||||||
|
matches = re.findall(r"\b(\d{4,5})\b", text)
|
||||||
|
for m in matches:
|
||||||
|
val = int(m)
|
||||||
|
if 1000 <= val <= 50000:
|
||||||
|
dists.add(m)
|
||||||
|
by_method[label] = dists
|
||||||
|
return by_method
|
||||||
|
|
||||||
|
|
||||||
def pick_best(counts, expected=None):
|
def pick_best(counts, expected=None):
|
||||||
"""Pick the best candidate from vote counts. Prefer expected value if present."""
|
"""Pick the best candidate from vote counts. Prefer expected value if present."""
|
||||||
if not counts:
|
if not counts:
|
||||||
@@ -150,12 +198,12 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
|
|||||||
Returns:
|
Returns:
|
||||||
dict with keys: time, distance, time_votes, distance_votes
|
dict with keys: time, distance, time_votes, distance_votes
|
||||||
"""
|
"""
|
||||||
texts = preprocess_and_ocr(img_path)
|
labeled_texts = preprocess_and_ocr(img_path)
|
||||||
if not texts:
|
if not labeled_texts:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
time_counts = extract_times(texts)
|
time_counts = extract_times(labeled_texts)
|
||||||
dist_counts = extract_distances(texts)
|
dist_counts = extract_distances(labeled_texts)
|
||||||
|
|
||||||
best_time = pick_best(time_counts, expected_time)
|
best_time = pick_best(time_counts, expected_time)
|
||||||
best_distance = pick_best(dist_counts, expected_distance)
|
best_distance = pick_best(dist_counts, expected_distance)
|
||||||
@@ -165,6 +213,7 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
|
|||||||
"distance": int(best_distance) if best_distance else None,
|
"distance": int(best_distance) if best_distance else None,
|
||||||
"time_votes": time_counts,
|
"time_votes": time_counts,
|
||||||
"distance_votes": dist_counts,
|
"distance_votes": dist_counts,
|
||||||
|
"labeled_texts": labeled_texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -245,6 +294,12 @@ def main():
|
|||||||
correct_dist = 0
|
correct_dist = 0
|
||||||
total = 0
|
total = 0
|
||||||
|
|
||||||
|
# Per-method accuracy tracking: method -> {time_correct, time_wrong, time_abstain, ...}
|
||||||
|
method_stats = defaultdict(lambda: {
|
||||||
|
"time_correct": 0, "time_wrong": 0, "time_abstain": 0,
|
||||||
|
"dist_correct": 0, "dist_wrong": 0, "dist_abstain": 0,
|
||||||
|
})
|
||||||
|
|
||||||
print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
|
print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
|
||||||
|
|
||||||
for base, paths in sorted(by_base.items()):
|
for base, paths in sorted(by_base.items()):
|
||||||
@@ -255,12 +310,14 @@ def main():
|
|||||||
# Try each crop, collect all votes
|
# Try each crop, collect all votes
|
||||||
all_time_votes = Counter()
|
all_time_votes = Counter()
|
||||||
all_dist_votes = Counter()
|
all_dist_votes = Counter()
|
||||||
|
all_labeled_texts = []
|
||||||
|
|
||||||
for p in paths:
|
for p in paths:
|
||||||
result = extract_screen_data(str(p))
|
result = extract_screen_data(str(p))
|
||||||
if result:
|
if result:
|
||||||
all_time_votes.update(result["time_votes"])
|
all_time_votes.update(result["time_votes"])
|
||||||
all_dist_votes.update(result["distance_votes"])
|
all_dist_votes.update(result["distance_votes"])
|
||||||
|
all_labeled_texts.extend(result["labeled_texts"])
|
||||||
|
|
||||||
best_time = pick_best(all_time_votes)
|
best_time = pick_best(all_time_votes)
|
||||||
best_dist = pick_best(all_dist_votes)
|
best_dist = pick_best(all_dist_votes)
|
||||||
@@ -274,6 +331,27 @@ def main():
|
|||||||
if dist_ok:
|
if dist_ok:
|
||||||
correct_dist += 1
|
correct_dist += 1
|
||||||
|
|
||||||
|
# Track per-method accuracy against ground truth
|
||||||
|
if gt:
|
||||||
|
times_by_method = extract_times_by_method(all_labeled_texts)
|
||||||
|
dists_by_method = extract_distances_by_method(all_labeled_texts)
|
||||||
|
|
||||||
|
for method, times in times_by_method.items():
|
||||||
|
if not times:
|
||||||
|
method_stats[method]["time_abstain"] += 1
|
||||||
|
elif expected_time in times:
|
||||||
|
method_stats[method]["time_correct"] += 1
|
||||||
|
else:
|
||||||
|
method_stats[method]["time_wrong"] += 1
|
||||||
|
|
||||||
|
for method, dists in dists_by_method.items():
|
||||||
|
if not dists:
|
||||||
|
method_stats[method]["dist_abstain"] += 1
|
||||||
|
elif expected_dist in dists:
|
||||||
|
method_stats[method]["dist_correct"] += 1
|
||||||
|
else:
|
||||||
|
method_stats[method]["dist_wrong"] += 1
|
||||||
|
|
||||||
# Display
|
# Display
|
||||||
status_t = ""
|
status_t = ""
|
||||||
status_d = ""
|
status_d = ""
|
||||||
@@ -294,6 +372,20 @@ def main():
|
|||||||
print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
|
print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
|
||||||
print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
|
print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
|
||||||
|
|
||||||
|
if method_stats:
|
||||||
|
print(f"\nMethod Accuracy:")
|
||||||
|
print(
|
||||||
|
f" {'Method':<20s} {'Time OK':>8s} {'Time Wrong':>11s} {'Time Abstain':>13s}"
|
||||||
|
f" {'Dist OK':>8s} {'Dist Wrong':>11s} {'Dist Abstain':>13s}"
|
||||||
|
)
|
||||||
|
for method in sorted(method_stats):
|
||||||
|
s = method_stats[method]
|
||||||
|
print(
|
||||||
|
f" {method:<20s} {s['time_correct']:>8d} {s['time_wrong']:>11d}"
|
||||||
|
f" {s['time_abstain']:>13d} {s['dist_correct']:>8d}"
|
||||||
|
f" {s['dist_wrong']:>11d} {s['dist_abstain']:>13d}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user