Optimise parameters
This commit is contained in:
92
README.md
Normal file
92
README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Rowing Stats
|
||||
|
||||
Extract workout data from photos of Concept 2 PM5 rowing machine displays using computer vision and Claude's vision API.
|
||||
|
||||
## How It Works
|
||||
|
||||
Photos go through a three-stage pipeline:
|
||||
|
||||
```
|
||||
photos/ → crop_to_screen.py → screen_classifier.py → extract_screen_data.py → rowing_results.csv
|
||||
```
|
||||
|
||||
1. **Screen Detection** (`crop_to_screen.py`) — Finds and perspective-corrects the LCD screen region using OpenCV edge detection, contour filtering, and morphological operations. Candidates are scored by `edge_density × area × rectangularity`.
|
||||
2. **Classification** (`screen_classifier.py`) — Filters out non-rowing images. Supports a rule-based feature scorer (no training needed) and a 4-layer CNN with batch norm.
|
||||
3. **Data Extraction** (`extract_screen_data.py`) — Extracts time and distance from cropped screen images using Tesseract OCR with multiple preprocessing variants (CLAHE, thresholding, scaling) and majority-vote extraction.
|
||||
|
||||
There is also `extract_rowing_data.py`, which uses Claude Haiku's vision API instead of Tesseract for data extraction. This serves as a reference/test for validating OCR accuracy but is more expensive to run due to API costs.
|
||||
|
||||
There is also an Optuna-based hyperparameter tuner (`optimize_crop.py`) for the screen detection parameters.
|
||||
|
||||
## Setup
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
pip install anthropic torch torchvision opencv-python Pillow numpy optuna
|
||||
```
|
||||
|
||||
### API Key
|
||||
|
||||
Create a `.env` file with your Anthropic API key:
|
||||
|
||||
```
|
||||
ANTHROPIC_API_KEY=sk-ant-...
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Full pipeline
|
||||
|
||||
```bash
|
||||
# 1. Crop screens from photos
|
||||
python crop_to_screen.py photos/ cropped/
|
||||
|
||||
# 2. Classify — keep only rowing displays
|
||||
python screen_classifier.py predict --dir cropped/
|
||||
|
||||
# 3. Extract workout data via Tesseract OCR
|
||||
python extract_screen_data.py --dir cropped/
|
||||
|
||||
# 3b. (Test) Extract via Claude API — more expensive, useful for validating OCR accuracy
|
||||
python extract_rowing_data.py --dir photos/
|
||||
```
|
||||
|
||||
### Individual commands
|
||||
|
||||
```bash
|
||||
# Classify a single image (feature-based or CNN)
|
||||
python screen_classifier.py predict --image path/to/img.jpg
|
||||
python screen_classifier.py predict --image path/to/img.jpg --mode cnn
|
||||
|
||||
# Extract data from a single image (Tesseract OCR)
|
||||
python extract_screen_data.py --image path/to/img.jpg
|
||||
|
||||
# Extract data from a single image (Claude API — for testing/validation)
|
||||
python extract_rowing_data.py --image path/to/img.jpg
|
||||
|
||||
# Train the CNN classifier
|
||||
python screen_classifier.py train --data-dir train/
|
||||
|
||||
# Optimize crop detection parameters
|
||||
python optimize_crop.py --n-trials 300 --photos-dir photos/
|
||||
```
|
||||
|
||||
## Training Data
|
||||
|
||||
The CNN classifier trains on labeled images in `train/`:
|
||||
|
||||
- `train/0/` — non-rowing images (negatives)
|
||||
- `train/1/` — rowing display images (positives)
|
||||
|
||||
The trained model is saved as `screen_classifier_model.pth`.
|
||||
|
||||
## Validation
|
||||
|
||||
Extracted metrics are validated against sensible bounds:
|
||||
|
||||
| Metric | Min | Max |
|
||||
| -------- | --------- | --------- |
|
||||
| Distance | 100 m | 100,000 m |
|
||||
| Time | 30 s | 2 hrs |
|
||||
| Pace | 1:20/500m | 2:30/500m |
|
||||
@@ -48,7 +48,8 @@ def ocr_image(img, psm=6, whitelist="0123456789:."):
|
||||
def preprocess_and_ocr(img_path):
|
||||
"""Generate multiple preprocessed variants and OCR each one.
|
||||
|
||||
Returns a list of OCR text results from different preprocessing pipelines.
|
||||
Returns a list of (label, text) tuples from different preprocessing pipelines.
|
||||
Labels like 's2_clahe_t120', 's3_otsu', 's4_otsu_inv'.
|
||||
"""
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
@@ -57,35 +58,37 @@ def preprocess_and_ocr(img_path):
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
results = []
|
||||
|
||||
for scale in [2, 3, 4]:
|
||||
for scale in [2, 3]:
|
||||
scaled = cv2.resize(
|
||||
gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
|
||||
)
|
||||
|
||||
# CLAHE + fixed thresholds
|
||||
# CLAHE + fixed thresholds (only t120/t130 at scale 2 are useful)
|
||||
if scale == 2:
|
||||
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
||||
enhanced = clahe.apply(scaled)
|
||||
for thresh_val in [120, 130, 140, 150, 160]:
|
||||
for thresh_val in [120, 130]:
|
||||
_, thresh = cv2.threshold(enhanced, thresh_val, 255, cv2.THRESH_BINARY)
|
||||
results.append(ocr_image(thresh))
|
||||
results.append((f"s{scale}_clahe_t{thresh_val}", ocr_image(thresh)))
|
||||
|
||||
# Otsu
|
||||
blur = cv2.GaussianBlur(scaled, (3, 3), 0)
|
||||
_, otsu = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||||
results.append(ocr_image(otsu))
|
||||
results.append((f"s{scale}_otsu", ocr_image(otsu)))
|
||||
|
||||
# Inverted Otsu
|
||||
results.append(ocr_image(255 - otsu))
|
||||
results.append((f"s{scale}_otsu_inv", ocr_image(255 - otsu)))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def extract_times(texts):
|
||||
def extract_times(labeled_texts):
|
||||
"""Extract time candidates from OCR texts using majority voting.
|
||||
|
||||
Supports both H:MM:SS (hour+ workouts) and MM:SS formats.
|
||||
All times are normalized to MM:SS format for downstream consistency.
|
||||
"""
|
||||
texts = [t for _, t in labeled_texts]
|
||||
time_counts = Counter()
|
||||
for text in texts:
|
||||
# Match H:MM:SS patterns first (e.g. 1:00:08)
|
||||
@@ -117,8 +120,9 @@ def extract_times(texts):
|
||||
return time_counts
|
||||
|
||||
|
||||
def extract_distances(texts):
|
||||
def extract_distances(labeled_texts):
|
||||
"""Extract distance candidates from OCR texts using majority voting."""
|
||||
texts = [t for _, t in labeled_texts]
|
||||
dist_counts = Counter()
|
||||
for text in texts:
|
||||
# Match 4-5 digit numbers (plausible rowing distances: 1000-99999m)
|
||||
@@ -130,6 +134,50 @@ def extract_distances(texts):
|
||||
return dist_counts
|
||||
|
||||
|
||||
def extract_times_by_method(labeled_texts):
|
||||
"""Extract time candidates grouped by method label.
|
||||
|
||||
Returns dict mapping method label to set of extracted time strings.
|
||||
"""
|
||||
by_method = {}
|
||||
for label, text in labeled_texts:
|
||||
times = set()
|
||||
hms_matches = re.findall(r"(\d{1,2}):(\d{2}):(\d{2})", text)
|
||||
consumed = set()
|
||||
for h, m, s in hms_matches:
|
||||
hours, mins, secs = int(h), int(m), int(s)
|
||||
total_mins = hours * 60 + mins
|
||||
if 10 <= total_mins <= 120 and 0 <= secs <= 59:
|
||||
times.add(f"{total_mins}:{secs:02d}")
|
||||
for match in re.finditer(r"\d{1,2}:\d{2}:\d{2}", text):
|
||||
consumed.add((match.start(), match.end()))
|
||||
for match in re.finditer(r"(\d{2}):(\d{2})", text):
|
||||
if any(match.start() >= cs and match.end() <= ce for cs, ce in consumed):
|
||||
continue
|
||||
mins, secs = int(match.group(1)), int(match.group(2))
|
||||
if 10 <= mins <= 120 and 0 <= secs <= 59:
|
||||
times.add(f"{mins}:{secs:02d}")
|
||||
by_method[label] = times
|
||||
return by_method
|
||||
|
||||
|
||||
def extract_distances_by_method(labeled_texts):
|
||||
"""Extract distance candidates grouped by method label.
|
||||
|
||||
Returns dict mapping method label to set of extracted distance strings.
|
||||
"""
|
||||
by_method = {}
|
||||
for label, text in labeled_texts:
|
||||
dists = set()
|
||||
matches = re.findall(r"\b(\d{4,5})\b", text)
|
||||
for m in matches:
|
||||
val = int(m)
|
||||
if 1000 <= val <= 50000:
|
||||
dists.add(m)
|
||||
by_method[label] = dists
|
||||
return by_method
|
||||
|
||||
|
||||
def pick_best(counts, expected=None):
|
||||
"""Pick the best candidate from vote counts. Prefer expected value if present."""
|
||||
if not counts:
|
||||
@@ -150,12 +198,12 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
|
||||
Returns:
|
||||
dict with keys: time, distance, time_votes, distance_votes
|
||||
"""
|
||||
texts = preprocess_and_ocr(img_path)
|
||||
if not texts:
|
||||
labeled_texts = preprocess_and_ocr(img_path)
|
||||
if not labeled_texts:
|
||||
return None
|
||||
|
||||
time_counts = extract_times(texts)
|
||||
dist_counts = extract_distances(texts)
|
||||
time_counts = extract_times(labeled_texts)
|
||||
dist_counts = extract_distances(labeled_texts)
|
||||
|
||||
best_time = pick_best(time_counts, expected_time)
|
||||
best_distance = pick_best(dist_counts, expected_distance)
|
||||
@@ -165,6 +213,7 @@ def extract_screen_data(img_path, expected_time=None, expected_distance=None):
|
||||
"distance": int(best_distance) if best_distance else None,
|
||||
"time_votes": time_counts,
|
||||
"distance_votes": dist_counts,
|
||||
"labeled_texts": labeled_texts,
|
||||
}
|
||||
|
||||
|
||||
@@ -245,6 +294,12 @@ def main():
|
||||
correct_dist = 0
|
||||
total = 0
|
||||
|
||||
# Per-method accuracy tracking: method -> {time_correct, time_wrong, time_abstain, ...}
|
||||
method_stats = defaultdict(lambda: {
|
||||
"time_correct": 0, "time_wrong": 0, "time_abstain": 0,
|
||||
"dist_correct": 0, "dist_wrong": 0, "dist_abstain": 0,
|
||||
})
|
||||
|
||||
print(f"Processing {len(by_base)} unique images from {len(image_paths)} crops...\n")
|
||||
|
||||
for base, paths in sorted(by_base.items()):
|
||||
@@ -255,12 +310,14 @@ def main():
|
||||
# Try each crop, collect all votes
|
||||
all_time_votes = Counter()
|
||||
all_dist_votes = Counter()
|
||||
all_labeled_texts = []
|
||||
|
||||
for p in paths:
|
||||
result = extract_screen_data(str(p))
|
||||
if result:
|
||||
all_time_votes.update(result["time_votes"])
|
||||
all_dist_votes.update(result["distance_votes"])
|
||||
all_labeled_texts.extend(result["labeled_texts"])
|
||||
|
||||
best_time = pick_best(all_time_votes)
|
||||
best_dist = pick_best(all_dist_votes)
|
||||
@@ -274,6 +331,27 @@ def main():
|
||||
if dist_ok:
|
||||
correct_dist += 1
|
||||
|
||||
# Track per-method accuracy against ground truth
|
||||
if gt:
|
||||
times_by_method = extract_times_by_method(all_labeled_texts)
|
||||
dists_by_method = extract_distances_by_method(all_labeled_texts)
|
||||
|
||||
for method, times in times_by_method.items():
|
||||
if not times:
|
||||
method_stats[method]["time_abstain"] += 1
|
||||
elif expected_time in times:
|
||||
method_stats[method]["time_correct"] += 1
|
||||
else:
|
||||
method_stats[method]["time_wrong"] += 1
|
||||
|
||||
for method, dists in dists_by_method.items():
|
||||
if not dists:
|
||||
method_stats[method]["dist_abstain"] += 1
|
||||
elif expected_dist in dists:
|
||||
method_stats[method]["dist_correct"] += 1
|
||||
else:
|
||||
method_stats[method]["dist_wrong"] += 1
|
||||
|
||||
# Display
|
||||
status_t = ""
|
||||
status_d = ""
|
||||
@@ -294,6 +372,20 @@ def main():
|
||||
print(f" Time: {correct_time}/{total} ({correct_time / total * 100:.0f}%)")
|
||||
print(f" Distance: {correct_dist}/{total} ({correct_dist / total * 100:.0f}%)")
|
||||
|
||||
if method_stats:
|
||||
print(f"\nMethod Accuracy:")
|
||||
print(
|
||||
f" {'Method':<20s} {'Time OK':>8s} {'Time Wrong':>11s} {'Time Abstain':>13s}"
|
||||
f" {'Dist OK':>8s} {'Dist Wrong':>11s} {'Dist Abstain':>13s}"
|
||||
)
|
||||
for method in sorted(method_stats):
|
||||
s = method_stats[method]
|
||||
print(
|
||||
f" {method:<20s} {s['time_correct']:>8d} {s['time_wrong']:>11d}"
|
||||
f" {s['time_abstain']:>13d} {s['dist_correct']:>8d}"
|
||||
f" {s['dist_wrong']:>11d} {s['dist_abstain']:>13d}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user