#!/usr/bin/env python3
"""
Extract test sequences from dataset TSV and write to FASTA format.

Usage:
    python3 extract_test_seqs.py --input dataset.tsv --output test_seqs.fasta

Required columns: id, seq, split (other columns such as target, class,
binary_label are allowed and ignored).
Extracts rows whose split column matches --split (default: "test").
"""

import argparse
import csv
import sys

def extract_test_sequences(input_file, output_file, split_value="test"):
    """Write the sequences of one dataset split from a TSV file to FASTA.

    Args:
        input_file: Path to a tab-separated file whose header row contains at
            least the ``id``, ``seq`` and ``split`` columns.
        output_file: Path of the FASTA file to write (overwritten).
        split_value: Value of the ``split`` column to keep (default "test").
            Both this value and each cell are whitespace-stripped before
            comparison.

    Returns:
        Tuple ``(extracted, skipped)``: the number of FASTA records written and
        the number of rows whose split did not match. Matching rows with an
        empty sequence are reported on stderr and counted in neither.

    Raises:
        SystemExit: with status 1 if the header is missing or lacks a required
            column. In that case the output file is not created or truncated.
    """
    extracted = 0
    skipped = 0
    empty = 0
    split_value = split_value.strip()

    # newline="" is what the csv module requires for correct line handling.
    with open(input_file, "r", newline="", encoding="utf-8") as infile:
        reader = csv.DictReader(infile, delimiter="\t")

        # Check expected columns are present. fieldnames is None for an
        # empty file, so fall back to an empty list instead of set(None).
        expected = {"id", "seq", "split"}
        if not expected.issubset(reader.fieldnames or []):
            print(f"ERROR: Missing expected columns. Found: {reader.fieldnames}", file=sys.stderr)
            sys.exit(1)

        # Open the output only after validation so a bad input does not
        # clobber an existing FASTA file.
        with open(output_file, "w", encoding="utf-8") as outfile:
            for row in reader:
                # Short rows yield None for missing cells (DictReader restval),
                # so normalise to "" before stripping.
                if (row["split"] or "").strip() != split_value:
                    skipped += 1
                    continue

                seq_id = (row["id"] or "").strip()
                seq = (row["seq"] or "").strip()

                if not seq:
                    print(f"  WARNING: Empty sequence for {seq_id}, skipping.", file=sys.stderr)
                    empty += 1
                    continue

                outfile.write(f">{seq_id}\n{seq}\n")
                extracted += 1

    print("Done.")
    print(f"  Extracted: {extracted:,} {split_value} sequences")
    print(f"  Skipped:   {skipped:,} non-{split_value} sequences")
    if empty:
        print(f"  Empty:     {empty:,} {split_value} rows with no sequence")
    print(f"  Output:    {output_file}")
    return extracted, skipped


def main(argv=None):
    """Command-line entry point: parse arguments and run the extraction.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            ``sys.argv[1:]``, so the script's normal behavior is unchanged.

    Exits with status 1 and an ``ERROR:`` message on stderr when the input
    cannot be read or the output cannot be written (any ``OSError``), rather
    than showing a raw traceback.
    """
    parser = argparse.ArgumentParser(description="Extract test sequences from TSV to FASTA")
    parser.add_argument("--input",  "-i", required=True, help="Input TSV file (dataset.tsv)")
    parser.add_argument("--output", "-o", default="test_sequences.fasta", help="Output FASTA file")
    parser.add_argument("--split",  "-s", default="test",
                        help="Which split to extract: test, train, val (default: test)")
    args = parser.parse_args(argv)

    print(f"Reading {args.input}...")
    print(f"Extracting split='{args.split}' sequences...")
    try:
        extract_test_sequences(args.input, args.output, split_value=args.split)
    except OSError as err:
        # Report file problems in the script's ERROR-to-stderr style.
        print(f"ERROR: {err}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    main()