Skip to content

API Reference

This page is auto-generated from the source-code docstrings by mkdocstrings. For narrative examples and recipes see the hand-written Python API guide.

The reference is grouped by sub-package. Every public symbol exported from a sub-package is included; private helpers (those starting with _) are not.


Top-level

dta_gnn


End-to-End Pipeline (dta_gnn.training)

run_gnn_end_to_end

run_gnn_end_to_end(config: EndToEndConfig) -> EndToEndResult

Run the complete GNN training pipeline end-to-end.

Steps
  1. Parse UniProt accessions and map them to ChEMBL target IDs.
  2. Build a DTA dataset from ChEMBL using a scaffold split; save all required files (dataset.csv, compounds.csv, metadata.json) to a new timestamped run directory.
  3. Run a W&B Bayes hyperparameter sweep (validation set used for scoring).
  4. Train the final model with the best hyperparameters and log the run to the same W&B project.
  5. Return an EndToEndResult with test metrics and per-step timings.

Parameters:

Name Type Description Default
config EndToEndConfig

Pipeline configuration.

required

Returns:

Type Description
EndToEndResult

An EndToEndResult with all artifacts, metrics, and timing.

Raises:

Type Description
ValueError

If no ChEMBL targets can be resolved or the dataset is empty.

Source code in src/dta_gnn/training/end_to_end.py
def run_gnn_end_to_end(config: EndToEndConfig) -> EndToEndResult:
    """Run the complete GNN training pipeline end-to-end.

    Steps
    -----
    1. Parse UniProt accessions and map them to ChEMBL target IDs.
    2. Build a DTA dataset from ChEMBL using a scaffold split; save all
       required files (``dataset.csv``, ``compounds.csv``, ``metadata.json``)
       to a new timestamped run directory.
    3. Run a W&B Bayes hyperparameter sweep (validation set used for scoring).
    4. Train the final model with the best hyperparameters and log the run to
       the same W&B project.
    5. Return an :class:`EndToEndResult` with test metrics and per-step timings.

    Note:
        The run directory is created as soon as step 1 succeeds, i.e. before
        the dataset is built. A failure in any later step therefore leaves a
        (possibly partially populated) run directory behind.

    Args:
        config: Pipeline configuration.

    Returns:
        :class:`EndToEndResult` with all artifacts, metrics, and timing.

    Raises:
        ValueError: If no ChEMBL targets can be resolved or the dataset is empty.
        OSError: If ``compounds.csv`` or ``metadata.json`` cannot be written
            (the error is logged and re-raised).
    """
    # Filled in by the `_timed(...)` blocks below and summed in step 5.
    # NOTE(review): `_timed` is defined elsewhere; it presumably stores the
    # elapsed seconds of each block under the given key — confirm.
    timings: dict = {}
    arch = config.architecture

    # ------------------------------------------------------------------
    # Step 1: UniProt → ChEMBL mapping
    # ------------------------------------------------------------------
    with _timed("uniprot_mapping", timings):
        accessions = parse_uniprot_accessions(config.uniprot_ids)
        logger.info("Parsed {} UniProt accession(s): {}", len(accessions), accessions)

        # A configured SQLite path takes precedence; the web API is only used
        # when no local database was given.
        if config.sqlite_path:
            logger.info("Using SQLite source: {}", config.sqlite_path)
            mapping = map_uniprot_to_chembl_targets_sqlite(config.sqlite_path, accessions)
        else:
            logger.info("Using ChEMBL web API for UniProt→ChEMBL mapping")
            mapping = map_uniprot_to_chembl_targets_web(accessions)

        # Partially unmapped input is tolerated (warning only) ...
        if mapping.unmapped:
            logger.warning(
                "No ChEMBL targets found for {} accession(s): {}",
                len(mapping.unmapped),
                mapping.unmapped,
            )

        # ... but the pipeline cannot continue when nothing at all resolved.
        if not mapping.resolved_target_chembl_ids:
            raise ValueError(
                f"No ChEMBL target IDs could be resolved from UniProt "
                f"accession(s): {accessions}. "
                "Check that the accessions are valid and present in ChEMBL."
            )

        target_chembl_ids = mapping.resolved_target_chembl_ids
        logger.info(
            "Resolved {} ChEMBL target(s): {}",
            len(target_chembl_ids),
            target_chembl_ids,
        )

    # Create run directory (not timed — instantaneous)
    run_dir = create_run_dir(runs_root=config.runs_root)
    logger.info("Run directory: {}", run_dir)

    # ------------------------------------------------------------------
    # Step 2: Dataset building with scaffold split
    # ------------------------------------------------------------------
    with _timed("dataset_build", timings):
        # Same precedence rule as in step 1: local SQLite if configured.
        source_type: str = "sqlite" if config.sqlite_path else "web"
        pipeline = Pipeline(
            source_type=source_type,
            sqlite_path=config.sqlite_path,
        )

        # `build_dta` receives the output path, so dataset.csv is written by
        # the pipeline itself rather than by this function.
        dataset_path = run_dir / "dataset.csv"
        dataset_df = pipeline.build_dta(
            target_ids=target_chembl_ids,
            standard_types=config.standard_types,
            split_method="scaffold",
            test_size=config.test_size,
            val_size=config.val_size,
            output_path=str(dataset_path),
        )

        if dataset_df is None or dataset_df.empty:
            raise ValueError(
                "Dataset is empty after building. "
                "Verify that the target IDs have associated activity data "
                "in ChEMBL and that standard_types (if set) match available data."
            )

        # Save compounds.csv (required by train_gnn_on_run and optimize_gnn_wandb)
        # One row per distinct (molecule_chembl_id, smiles) pair; rows without
        # a SMILES string are dropped.
        compounds_df = (
            dataset_df[["molecule_chembl_id", "smiles"]]
            .drop_duplicates()
            .dropna(subset=["smiles"])
            .reset_index(drop=True)
        )
        compounds_path = run_dir / "compounds.csv"
        try:
            compounds_df.to_csv(compounds_path, index=False)
        except OSError as e:
            # Log for context, then let the caller see the original error.
            logger.error("Failed to write compounds.csv to {}: {}", compounds_path, e)
            raise

        # Compute split counts
        # Stays empty when the dataset carries no "split" column; the
        # `.get(..., 0)` lookups at the end of the function then report 0.
        split_counts: dict = {}
        if "split" in dataset_df.columns:
            # Cast to plain str/int so the mapping can be serialized as JSON.
            split_counts = {
                str(k): int(v)
                for k, v in dataset_df["split"].value_counts().items()
            }

        # Save metadata.json
        metadata = {
            "uniprot_ids": accessions,
            "target_chembl_ids": target_chembl_ids,
            "architecture": arch,
            "source_type": source_type,
            "split_method": "scaffold",
            "test_size": config.test_size,
            "val_size": config.val_size,
            "dataset_size": len(dataset_df),
            "split_counts": split_counts,
            # Timezone-aware UTC timestamp in ISO-8601 form.
            "created_at": datetime.now(timezone.utc).isoformat(),
            "dta_gnn_version": __version__,
        }
        metadata_path = run_dir / "metadata.json"
        try:
            metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
        except OSError as e:
            logger.error("Failed to write metadata.json to {}: {}", metadata_path, e)
            raise

        logger.info(
            "Dataset built: {} rows total, splits: {}",
            len(dataset_df),
            split_counts,
        )

    # ------------------------------------------------------------------
    # Step 3: Hyperparameter search via W&B Bayes sweep
    # ------------------------------------------------------------------
    with _timed("hyperparameter_search", timings):
        # All five search dimensions are always enabled; only their bounds
        # come from the user-facing configuration.
        hpo_config = HyperoptConfig(
            model_type="GNN",
            architecture=arch,
            n_trials=config.n_trials,
            device=config.device,
            # Defaults: search lr, embedding_dim, hidden_dim, num_layers, dropout
            optimize_lr=True,
            lr_min=config.lr_min,
            lr_max=config.lr_max,
            optimize_embedding_dim=True,
            embedding_dim_min=config.embedding_dim_min,
            embedding_dim_max=config.embedding_dim_max,
            optimize_hidden_dim=True,
            hidden_dim_min=config.hidden_dim_min,
            hidden_dim_max=config.hidden_dim_max,
            optimize_num_layers=True,
            num_layers_min=config.num_layers_min,
            num_layers_max=config.num_layers_max,
            optimize_dropout=True,
            dropout_min=config.dropout_min,
            dropout_max=config.dropout_max,
        )

        hyperopt_result = optimize_gnn_wandb(
            run_dir,
            config=hpo_config,
            project=config.wandb_project,
            entity=config.wandb_entity,
            api_key=config.wandb_api_key,
            sweep_name=f"gnn_{arch}_hpo",
        )

        # NOTE(review): the `{:.4f}` spec assumes `best_value` is numeric —
        # confirm it can never be None when the sweep yields no usable trial.
        logger.info(
            "HPO complete — best val R²: {:.4f} (trial {}), best params: {}",
            hyperopt_result.best_value,
            hyperopt_result.best_trial_number,
            hyperopt_result.best_params,
        )

    # ------------------------------------------------------------------
    # Step 4: Final model training with best hyperparameters
    # ------------------------------------------------------------------
    with _timed("final_training", timings):
        final_config = _best_params_to_gnn_config(hyperopt_result.best_params, config)

        wandb = _require_wandb()
        wandb_run = None
        try:
            # Only log in explicitly when a non-blank key was supplied.
            if config.wandb_api_key and str(config.wandb_api_key).strip():
                wandb.login(key=str(config.wandb_api_key).strip(), relogin=True)

            wandb_run = wandb.init(
                project=config.wandb_project,
                entity=config.wandb_entity or None,
                name=f"gnn_{arch}_final",
                config={
                    "uniprot_ids": accessions,
                    "target_chembl_ids": target_chembl_ids,
                    "architecture": arch,
                    "dataset_size": len(dataset_df),
                    "split_counts": split_counts,
                    # Dataclass fields are unpacked last, so a field whose name
                    # matches one of the keys above overrides that key.
                    **{f: getattr(final_config, f) for f in final_config.__dataclass_fields__},
                },
                tags=["final_training", arch],
            )
            logger.info("W&B final training run: {}", wandb_run.url if wandb_run else "n/a")

            train_result = train_gnn_on_run(
                run_dir,
                config=final_config,
                wandb_run=wandb_run,
            )
        finally:
            # Close the W&B run even when login, init, or training raised.
            if wandb_run is not None:
                wandb_run.finish()

        logger.info(
            "Final training complete — best epoch: {}, model: {}",
            train_result.best_epoch,
            train_result.model_path,
        )

    # ------------------------------------------------------------------
    # Step 5: Extract test metrics and report timing
    # ------------------------------------------------------------------
    # Defensive lookup: missing metrics, a non-dict "splits" entry, or a
    # missing/empty "test" entry all collapse to an empty dict.
    test_metrics: dict = {}
    if train_result.metrics and isinstance(train_result.metrics.get("splits"), dict):
        test_metrics = train_result.metrics["splits"].get("test") or {}

    # Sum of the timed steps only; untimed work (e.g. run directory creation)
    # is not included.
    total_time = sum(timings.values())
    logger.info("=" * 60)
    logger.info(
        "End-to-end pipeline completed in {:.1f}s ({:.1f} min)",
        total_time,
        total_time / 60.0,
    )
    for step_name, step_time in timings.items():
        # Guard against division by zero when every step took no time.
        pct = 100.0 * step_time / total_time if total_time > 0 else 0.0
        logger.info("  {:<30s}  {:6.1f}s  ({:.0f}%)", step_name, step_time, pct)
    logger.info("Test metrics: {}", test_metrics)
    logger.info("Run directory: {}", run_dir)
    logger.info("=" * 60)

    return EndToEndResult(
        run_dir=run_dir,
        uniprot_ids=accessions,
        target_chembl_ids=target_chembl_ids,
        architecture=arch,
        dataset_size=len(dataset_df),
        # These lookups assume the split labels are "train" / "val" / "test";
        # any label that is absent is reported as 0.
        train_size=split_counts.get("train", 0),
        val_size_actual=split_counts.get("val", 0),
        test_size_actual=split_counts.get("test", 0),
        hyperopt_result=hyperopt_result,
        train_result=train_result,
        test_metrics=test_metrics,
        timings=timings,
    )

EndToEndConfig dataclass

EndToEndConfig(uniprot_ids: str, architecture: Literal['gin', 'gcn', 'gat', 'sage', 'pna', 'transformer', 'tag', 'arma', 'cheb', 'supergat'] = 'gin', sqlite_path: str | None = None, standard_types: list[str] | None = None, test_size: float = 0.2, val_size: float = 0.1, wandb_project: str = 'dta_gnn', wandb_entity: str | None = None, wandb_api_key: str | None = None, n_trials: int = 20, lr_min: float = 1e-05, lr_max: float = 0.01, embedding_dim_min: int = 32, embedding_dim_max: int = 256, hidden_dim_min: int = 32, hidden_dim_max: int = 256, num_layers_min: int = 1, num_layers_max: int = 5, dropout_min: float = 0.0, dropout_max: float = 0.5, epochs: int = 30, batch_size: int = 64, runs_root: str = 'runs', device: str | None = None)

Configuration for the end-to-end GNN training pipeline.

Parameters:

Name Type Description Default
uniprot_ids str

One or more UniProt accessions (comma/space/semicolon-separated). Example: "P00533" or "P00533, P04637".

required
architecture Literal['gin', 'gcn', 'gat', 'sage', 'pna', 'transformer', 'tag', 'arma', 'cheb', 'supergat']

GNN architecture to train and tune.

'gin'
sqlite_path str | None

Path to a local ChEMBL SQLite database. When provided, all ChEMBL data is fetched from this file. When None, the ChEMBL web API is used as a fallback.

None
standard_types list[str] | None

Activity standard types to include (e.g. ["IC50", "Ki"]). None keeps all types.

None
test_size float

Fraction of data reserved for the test split.

0.2
val_size float

Fraction of data reserved for the validation split.

0.1
wandb_project str

W&B project name used for both the HPO sweep and the final training run.

'dta_gnn'
wandb_entity str | None

W&B entity (username or team). None uses the default.

None
wandb_api_key str | None

W&B API key. None relies on WANDB_API_KEY env variable or interactive login.

None
n_trials int

Number of W&B Bayes sweep trials for hyperparameter search.

20
lr_min float

Lower bound for learning-rate search (log-uniform).

1e-05
lr_max float

Upper bound for learning-rate search (log-uniform).

0.01
embedding_dim_min int

Lower bound for embedding dimension search.

32
embedding_dim_max int

Upper bound for embedding dimension search.

256
hidden_dim_min int

Lower bound for hidden dimension search.

32
hidden_dim_max int

Upper bound for hidden dimension search.

256
num_layers_min int

Lower bound for number of GNN layers search.

1
num_layers_max int

Upper bound for number of GNN layers search.

5
dropout_min float

Lower bound for dropout rate search.

0.0
dropout_max float

Upper bound for dropout rate search.

0.5
epochs int

Number of training epochs for the final model (HPO trials use fewer epochs internally).

30
batch_size int

Mini-batch size for both HPO and final training.

64
runs_root str

Root directory under which timestamped run directories are created (default: "runs").

'runs'
device str | None

PyTorch device string. None auto-detects (MPS > CUDA > CPU).

None

EndToEndResult dataclass

EndToEndResult(run_dir: Path, uniprot_ids: list[str], target_chembl_ids: list[str], architecture: str, dataset_size: int, train_size: int, val_size_actual: int, test_size_actual: int, hyperopt_result: HyperoptResult, train_result: GnnTrainResult, test_metrics: dict, timings: dict)

Result of a complete end-to-end GNN training run.

Attributes:

Name Type Description
run_dir Path

Path to the timestamped run directory holding all artifacts.

uniprot_ids list[str]

Validated UniProt accessions used as input.

target_chembl_ids list[str]

Resolved ChEMBL target IDs.

architecture str

GNN architecture that was trained.

dataset_size int

Total number of rows in the built dataset.

train_size int

Number of training rows.

val_size_actual int

Number of validation rows.

test_size_actual int

Number of test rows.

hyperopt_result HyperoptResult

Full result from optimize_gnn_wandb.

train_result GnnTrainResult

Full result from train_gnn_on_run.

test_metrics dict

Test-split metrics dict (r2, rmse, mae, …).

timings dict

Wall-clock time in seconds for each pipeline step.


Dataset Pipeline (dta_gnn.pipeline)

Pipeline

Pipeline(source_type: Literal['web', 'sqlite'] = 'web', sqlite_path: Optional[str] = None)
Source code in src/dta_gnn/pipeline.py
def __init__(
    self,
    source_type: Literal["web", "sqlite"] = "web",
    sqlite_path: Optional[str] = None,
):
    """Pick and instantiate the ChEMBL backend stored in ``self.source``.

    Args:
        source_type: ``"sqlite"`` selects the local database backend; every
            other value selects the web backend.
        sqlite_path: Path handed to the SQLite backend. Only consulted when
            ``source_type`` is ``"sqlite"``.

    Raises:
        ValueError: If ``source_type`` is ``"sqlite"`` and ``sqlite_path`` is
            empty or ``None``.
    """
    if source_type != "sqlite":
        # Imported here, not at module level, so that merely importing this
        # module does not pull in chembl_webresource_client.
        from dta_gnn.io.web_source import ChemblWebSource

        self.source = ChemblWebSource()
        return

    if not sqlite_path:
        raise ValueError("sqlite_path required for sqlite source")

    # Deferred import: lets the UI start even if the ChEMBL Web API is down.
    from dta_gnn.io.sqlite_source import ChemblSQLiteSource

    self.source = ChemblSQLiteSource(sqlite_path)

build_dta

build_dta(*, target_ids: Optional[List[str]] = None, molecule_ids: Optional[List[str]] = None, standard_types: Optional[List[str]] = None, split_method: str = 'random', output_path: Optional[str] = None, test_size: float = 0.2, val_size: float = 0.1, split_year: int = 2022, featurize: bool = False, progress_callback: Optional[callable] = None) -> pd.DataFrame

Build a DTA-style regression dataset.

The regression label is always pchembl_value (after optional cleaning). Target sequences/metadata are stored separately in self.last_targets_csv.

Source code in src/dta_gnn/pipeline.py
def build_dta(
    self,
    *,
    target_ids: Optional[List[str]] = None,
    molecule_ids: Optional[List[str]] = None,
    standard_types: Optional[List[str]] = None,
    split_method: str = "random",
    output_path: Optional[str] = None,
    test_size: float = 0.2,
    val_size: float = 0.1,
    split_year: int = 2022,
    featurize: bool = False,
    progress_callback: Optional[callable] = None,
) -> pd.DataFrame:
    """Build a DTA-style regression dataset.

    The regression label is always `pchembl_value` (after optional cleaning).
    Target sequences/metadata are stored separately in `self.last_targets_csv`.

    Args:
        target_ids: Forwarded to the source's ``fetch_activities``.
        molecule_ids: Forwarded to the source's ``fetch_activities``.
        standard_types: Forwarded to the source's ``fetch_activities``.
        split_method: One of ``"random"``, ``"scaffold"`` or ``"temporal"``.
        output_path: When given, the final dataset is written there as CSV
            (without the index).
        test_size: Passed to the random and scaffold splitters; the temporal
            splitter does not receive it.
        val_size: Passed to all three splitters.
        split_year: Passed to the temporal splitter only.
        featurize: When true, Morgan fingerprints (radius 2, 2048 bits) are
            computed for the dataset before splitting.
        progress_callback: Forwarded to the source's ``fetch_activities``.

    Returns:
        The dataset as returned by the selected splitter, including a
        ``label`` column copied from ``pchembl_value``. An empty DataFrame is
        returned when no activities were fetched or none of them has a
        ``pchembl_value``; ``self.last_targets_csv`` is ``None`` in that case.

    Raises:
        ValueError: If ``split_method`` is not a known method. This is checked
            up front, before any data is fetched.
        OSError: If the targets CSV or the dataset CSV cannot be written (the
            error is logged and re-raised).
    """
    # Fail fast: a typo in split_method should not cost a full ChEMBL fetch.
    if split_method not in ("random", "scaffold", "temporal"):
        raise ValueError(f"Unknown split_method: {split_method}")

    # Reset first so that an early return below can neither leave the
    # attribute undefined nor expose the path produced by a previous build.
    self.last_targets_csv = None

    logger.info("Starting DTA build for targets: {}", target_ids)

    df_activities = self.source.fetch_activities(
        target_ids=target_ids,
        molecule_ids=molecule_ids,
        standard_types=standard_types,
        progress_callback=progress_callback,
    )

    if df_activities is None or len(df_activities) == 0:
        return pd.DataFrame()

    # Normalize/ensure pChEMBL.
    df_clean = standardize_activities(df_activities, convert_to_pchembl=True)
    df_agg = aggregate_duplicates(df_clean)

    # Rows without a pChEMBL value cannot serve as regression examples.
    dataset_df = df_agg.dropna(subset=["pchembl_value"]).copy()
    if dataset_df.empty:
        return pd.DataFrame()
    dataset_df["label"] = dataset_df["pchembl_value"]

    # Molecules — prefer SMILES already fetched with activities
    if "canonical_smiles" in dataset_df.columns:
        dataset_df = dataset_df.rename(columns={"canonical_smiles": "smiles"})
    else:
        mol_ids = dataset_df["molecule_chembl_id"].unique().tolist()
        df_mols = self.source.fetch_molecules(mol_ids)
        # Left join: activities whose molecule lookup failed are kept.
        dataset_df = dataset_df.merge(df_mols, on="molecule_chembl_id", how="left")

    if featurize:
        dataset_df = calculate_morgan_fingerprints(
            dataset_df, radius=2, n_bits=2048
        )

    # Split (split_method was validated at the top of the function).
    if split_method == "random":
        # split_random returns a 4-tuple; only the first element is kept.
        dataset_df, _, _, _ = split_random(
            dataset_df, test_size=test_size, val_size=val_size
        )
    elif split_method == "scaffold":
        dataset_df = split_cold_drug_scaffold(
            dataset_df, test_size=test_size, val_size=val_size
        )
    else:  # "temporal"
        dataset_df = split_temporal(
            dataset_df, split_year=split_year, val_size=val_size
        )

    # Targets metadata (saved separately).
    unique_targets = sorted(
        set(dataset_df["target_chembl_id"].dropna().astype(str).tolist())
    )
    targets_df = self.source.fetch_targets(unique_targets)

    # Reserve a persistent temp file name; the handle is closed right away so
    # pandas can reopen the path on every platform.
    with tempfile.NamedTemporaryFile(
        prefix="targets_", suffix=".csv", delete=False
    ) as tmp:
        targets_csv_path = tmp.name
    try:
        targets_df.to_csv(targets_csv_path, index=False)
    except OSError as e:
        logger.error("Failed to write targets CSV to {}: {}", targets_csv_path, e)
        raise
    self.last_targets_csv = targets_csv_path

    if output_path:
        try:
            dataset_df.to_csv(output_path, index=False)
        except OSError as e:
            logger.error("Failed to write dataset CSV to {}: {}", output_path, e)
            raise

    return dataset_df

Data Sources (dta_gnn.io)

ChEMBL sources

ChemblSource

Bases: ABC

Abstract base class for ChEMBL data sources.

fetch_activities abstractmethod

fetch_activities(target_ids: Optional[List[str]] = None, molecule_ids: Optional[List[str]] = None, standard_types: Optional[List[str]] = None, progress_callback: Optional[callable] = None) -> pd.DataFrame

Fetch activity data.

Returns a DataFrame with the following columns: molecule_chembl_id, target_chembl_id, standard_type, standard_value, standard_units, standard_relation, pchembl_value (optional), and canonical_smiles (optional; when present it avoids a separate fetch_molecules call).

Source code in src/dta_gnn/io/chembl_source.py
@abstractmethod
def fetch_activities(
    self,
    target_ids: Optional[List[str]] = None,
    molecule_ids: Optional[List[str]] = None,
    standard_types: Optional[List[str]] = None,
    progress_callback: Optional[callable] = None,
) -> pd.DataFrame:
    """Fetch activity data (abstract; each concrete source implements it).

    Implementations return a DataFrame providing these columns:

    ===================  ==================================================
    column               note
    ===================  ==================================================
    molecule_chembl_id
    target_chembl_id
    standard_type
    standard_value
    standard_units
    standard_relation
    pchembl_value        optional
    canonical_smiles     optional, avoids a separate fetch_molecules call
    ===================  ==================================================
    """

fetch_molecules abstractmethod

fetch_molecules(molecule_ids: List[str]) -> pd.DataFrame

Fetch molecule structures.

Returns a DataFrame with the following columns: molecule_chembl_id and smiles.

Source code in src/dta_gnn/io/chembl_source.py
@abstractmethod
def fetch_molecules(self, molecule_ids: List[str]) -> pd.DataFrame:
    """Fetch molecule structures (abstract; each concrete source implements it).

    Implementations return a DataFrame providing two columns:
    ``molecule_chembl_id`` and ``smiles``.
    """

fetch_targets abstractmethod

fetch_targets(target_ids: List[str]) -> pd.DataFrame

Fetch target sequences.

Returns a DataFrame with the following columns: target_chembl_id, sequence, and organism.

Source code in src/dta_gnn/io/chembl_source.py
@abstractmethod
def fetch_targets(self, target_ids: List[str]) -> pd.DataFrame:
    """Fetch target sequences (abstract; each concrete source implements it).

    Implementations return a DataFrame providing three columns:
    ``target_chembl_id``, ``sequence`` and ``organism``.
    """

ChemblSQLiteSource

ChemblSQLiteSource(db_path: str)

Bases: ChemblSource

ChEMBL data source using a local SQLite database dump.

Source code in src/dta_gnn/io/sqlite_source.py
def __init__(self, db_path: str):
    """Remember the database path after checking that it exists.

    Args:
        db_path: Location of the ChEMBL SQLite file.

    Raises:
        FileNotFoundError: If nothing exists at ``db_path``.
    """
    import os

    db_present = os.path.exists(db_path)
    if db_present:
        self.db_path = db_path
        return
    raise FileNotFoundError(f"ChEMBL SQLite file not found at: {db_path}")

ChemblWebSource

Bases: ChemblSource

ChEMBL data source using the official Web Resource Client.

get_targets

get_targets(accession: Optional[str] = None) -> List[dict]

Get targets by UniProt accession.

Parameters:

Name Type Description Default
accession Optional[str]

UniProt accession (e.g., 'P00533' for EGFR)

None

Returns:

Type Description
List[dict]

List of target dictionaries with target_chembl_id

Source code in src/dta_gnn/io/web_source.py
def get_targets(self, accession: Optional[str] = None) -> List[dict]:
    """Look up ChEMBL targets for a UniProt accession.

    The target endpoint is queried first, filtered by component accession. If
    that query raises, the target_component endpoint is tried instead and the
    target IDs found in its records are returned. Failures of either lookup
    are logged as warnings, never raised.

    Args:
        accession: UniProt accession (e.g., 'P00533' for EGFR)

    Returns:
        List of target dictionaries with target_chembl_id. Empty when no
        accession is given, when nothing matches, or when both lookups fail.
    """
    target = new_client.target

    if not accession:
        return []

    def _query_by_component_accession():
        # The ChEMBL API allows filtering targets by component accession.
        matches = target.filter(target_components__accession=accession)
        return list(matches.only(["target_chembl_id", "pref_name", "organism"]))

    try:
        return retry_with_backoff(_query_by_component_accession)
    except Exception as primary_err:
        logger.warning("Primary target lookup failed for accession {!r}: {}", accession, primary_err)

    # Fallback: go through the target_component endpoint instead.
    try:
        component_client = new_client.target_component
        components = retry_with_backoff(
            lambda: list(component_client.filter(accession=accession))
        )

        # Collect the distinct target IDs referenced by the components.
        found_ids = {
            entry["target_chembl_id"]
            for comp in components
            if "targets" in comp
            for entry in comp["targets"]
            if "target_chembl_id" in entry
        }

        if found_ids:
            return [{"target_chembl_id": tid} for tid in found_ids]
    except Exception as fallback_err:
        logger.warning("Fallback target_component lookup also failed for accession {!r}: {}", accession, fallback_err)

    return []

UniProt → ChEMBL target mapping

UniProtToChEMBLResult dataclass

UniProtToChEMBLResult(resolved_target_chembl_ids: list[str], per_input: Mapping[str, list[str]], unmapped: list[str])

parse_uniprot_accessions

parse_uniprot_accessions(text: str) -> list[str]
Source code in src/dta_gnn/io/target_mapping.py
def parse_uniprot_accessions(text: str) -> list[str]:
    """Split free text into validated, upper-cased UniProt accessions.

    Tokens are separated by any run of whitespace, commas or semicolons.
    Duplicates are dropped; the first occurrence of each accession keeps its
    position.

    Raises:
        ValueError: If the text is empty/blank, or if any token does not match
            the module's UniProt accession pattern.
    """
    stripped = (text or "").strip()
    if not stripped:
        raise ValueError("No UniProt accessions provided")

    tokens = [piece.upper() for piece in re.split(r"[\s,;]+", stripped) if piece]

    rejected = [tok for tok in tokens if not _UNIPROT_RE.match(tok)]
    if rejected:
        raise ValueError(f"Invalid UniProt accession(s): {', '.join(rejected)}")

    # dict keys keep insertion order, which gives an order-preserving dedup.
    return list(dict.fromkeys(tokens))

parse_chembl_target_ids

parse_chembl_target_ids(text: str) -> list[str]
Source code in src/dta_gnn/io/target_mapping.py
def parse_chembl_target_ids(text: str) -> list[str]:
    """Split free text into validated, upper-cased ChEMBL target IDs.

    Tokens are separated by any run of whitespace, commas or semicolons.
    Duplicates are dropped; the first occurrence of each ID keeps its
    position.

    Raises:
        ValueError: If the text is empty/blank, or if any token does not match
            the module's ChEMBL target ID pattern.
    """
    stripped = (text or "").strip()
    if not stripped:
        raise ValueError("No ChEMBL target IDs provided")

    tokens = [piece.upper() for piece in re.split(r"[\s,;]+", stripped) if piece]

    rejected = [tok for tok in tokens if not _CHEMBL_TARGET_RE.match(tok)]
    if rejected:
        raise ValueError(f"Invalid ChEMBL target ID(s): {', '.join(rejected)}")

    # dict keys keep insertion order, which gives an order-preserving dedup.
    return list(dict.fromkeys(tokens))

map_uniprot_to_chembl_targets_sqlite

map_uniprot_to_chembl_targets_sqlite(sqlite_path: str | Path, accessions: Iterable[str]) -> UniProtToChEMBLResult
Source code in src/dta_gnn/io/target_mapping.py
def map_uniprot_to_chembl_targets_sqlite(
    sqlite_path: str | Path,
    accessions: Iterable[str],
) -> UniProtToChEMBLResult:
    """Map UniProt accessions to ChEMBL target IDs using a local SQLite dump.

    Accessions are upper-cased before the lookup. The lookup joins
    ``component_sequences``, ``target_components`` and ``target_dictionary``.

    Args:
        sqlite_path: Path to the ChEMBL SQLite database file.
        accessions: UniProt accessions to resolve.

    Returns:
        A result whose ``per_input`` maps every (upper-cased) input accession
        to the target IDs found for it, whose ``resolved_target_chembl_ids``
        is the sorted union of those IDs, and whose ``unmapped`` lists the
        accessions without any hit.

    Raises:
        FileNotFoundError: If ``sqlite_path`` does not exist.
    """
    # Local import: the block needs `closing`, which the original did not use.
    from contextlib import closing

    path = Path(sqlite_path)
    if not path.exists():
        raise FileNotFoundError(f"ChEMBL SQLite DB not found: {path}")

    input_accessions = [a.upper() for a in accessions]
    per_input: dict[str, list[str]] = {a: [] for a in input_accessions}

    if not input_accessions:
        return UniProtToChEMBLResult(
            resolved_target_chembl_ids=[], per_input=per_input, unmapped=[]
        )

    # Bind each distinct accession once; repeated inputs add nothing to an
    # IN (...) list and only consume bound-parameter slots.
    unique_accessions = list(per_input)
    placeholders = ",".join(["?"] * len(unique_accessions))
    query = f"""
        SELECT cs.accession, td.chembl_id
        FROM component_sequences cs
        JOIN target_components tc ON tc.component_id = cs.component_id
        JOIN target_dictionary td ON td.tid = tc.tid
        WHERE cs.accession IN ({placeholders})
    """

    # A sqlite3 connection used as a context manager only commits or rolls
    # back; it does not close. `closing` releases the handle deterministically.
    with closing(sqlite3.connect(str(path))) as conn:
        rows = conn.execute(query, unique_accessions).fetchall()

    for accession, chembl_id in rows:
        if accession is None or chembl_id is None:
            continue
        a = str(accession).upper()
        t = str(chembl_id).upper()
        # Keep first-seen order per accession while skipping duplicates.
        if a in per_input and t not in per_input[a]:
            per_input[a].append(t)

    resolved = sorted({tid for tids in per_input.values() for tid in tids})
    unmapped = [a for a, tids in per_input.items() if not tids]

    return UniProtToChEMBLResult(
        resolved_target_chembl_ids=resolved,
        per_input=per_input,
        unmapped=unmapped,
    )

map_uniprot_to_chembl_targets_web

map_uniprot_to_chembl_targets_web(accessions: Iterable[str]) -> UniProtToChEMBLResult

Web-based UniProt→ChEMBL mapping.

This is implemented as a thin fallback to the existing ChEMBL web client logic in the app. It keeps the API stable for the UI.

Source code in src/dta_gnn/io/target_mapping.py
def map_uniprot_to_chembl_targets_web(
    accessions: Iterable[str],
) -> UniProtToChEMBLResult:
    """Map UniProt accessions to ChEMBL target IDs through the web client.

    Each (upper-cased) accession is looked up individually via
    ``ChemblWebSource.get_targets``. A lookup that raises is logged as a
    warning and the accession is then treated as unmapped, so one failure
    never aborts the whole mapping.
    """

    # Deferred import of the web client wrapper.
    from dta_gnn.io.web_source import ChemblWebSource

    wanted = [acc.upper() for acc in accessions]
    by_accession: dict[str, list[str]] = {acc: [] for acc in wanted}

    client = ChemblWebSource()
    for acc in wanted:
        try:
            # Best-effort: match targets whose component accession matches.
            hits = client.get_targets(accession=acc)
            hit_ids = {
                hit["target_chembl_id"] for hit in hits if "target_chembl_id" in hit
            }
            by_accession[acc] = sorted(hit_ids)
        except Exception as exc:
            logger.warning(
                "ChEMBL Web API lookup failed for UniProt accession {!r}: {}", acc, exc
            )
            by_accession[acc] = []

    # Union of every per-accession hit list, in sorted order.
    all_ids: set[str] = set()
    for ids in by_accession.values():
        all_ids.update(ids)

    return UniProtToChEMBLResult(
        resolved_target_chembl_ids=sorted(all_ids),
        per_input=by_accession,
        unmapped=[acc for acc, ids in by_accession.items() if not ids],
    )

Run directories

RunDirResult dataclass

RunDirResult(run_dir: Path, current_link: Path)

create_run_dir

create_run_dir(*, runs_root: str | Path = 'runs') -> Path

Create a new timestamped run directory and update runs/current.

Returns the created run directory path.

Source code in src/dta_gnn/io/runs.py
def create_run_dir(*, runs_root: str | Path = "runs") -> Path:
    """Create a new timestamped run directory and update `runs/current`.

    Returns the created run directory path.
    """

    runs_root = Path(runs_root)
    runs_root.mkdir(parents=True, exist_ok=True)

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = runs_root / ts

    # Ensure uniqueness if called multiple times within the same second.
    suffix = 1
    while run_dir.exists():
        suffix += 1
        run_dir = runs_root / f"{ts}_{suffix}"

    run_dir.mkdir(parents=True, exist_ok=False)

    current = runs_root / "current"
    try:
        if current.is_symlink() or current.exists():
            if current.is_dir() and not current.is_symlink():
                # Avoid deleting a real directory with contents unexpectedly.
                # If it's a directory, we keep it and just return the run_dir.
                return run_dir
            current.unlink()

        # Prefer a relative symlink (matches existing repo structure).
        os.symlink(run_dir.name, str(current))
    except OSError:
        # Symlinks can fail on some platforms; fall back to writing a pointer.
        # The UI also resolves runs/current; if it's a file, it won't work.
        # So in that case, create a directory and store a marker.
        current.mkdir(parents=True, exist_ok=True)
        (current / "RUN_DIR.txt").write_text(str(run_dir.resolve()), encoding="utf-8")

    return run_dir

resolve_run_dir

resolve_run_dir(run_dir: str | Path | None) -> Path | None

Resolve and normalize a run directory path.

Parameters:

Name Type Description Default
run_dir str | Path | None

Path to run directory (string or Path object)

required

Returns:

Type Description
Path | None

Resolved Path object, or None if input is None or resolution fails

Source code in src/dta_gnn/io/runs.py
def resolve_run_dir(run_dir: str | Path | None) -> Path | None:
    """Resolve and normalize a run directory path.

    Args:
        run_dir: Path to run directory (string or Path object)

    Returns:
        Resolved Path object, or None if input is None or resolution fails
    """
    if run_dir is None:
        return None
    try:
        return Path(run_dir).expanduser().resolve()
    except Exception as e:
        logger.warning("Failed to resolve run_dir path {!r}: {}", run_dir, e)
        return None

resolve_current_run_dir

resolve_current_run_dir(*, hint: str = 'Build a dataset first.') -> Path

Resolve the current run folder.

Prefers runs/current if it exists (dir or symlink). If missing, raises FileNotFoundError.

Parameters:

Name Type Description Default
hint str

Optional hint message to include in error

'Build a dataset first.'

Returns:

Type Description
Path

Resolved Path to current run directory

Raises:

Type Description
FileNotFoundError

If runs/current does not exist

Source code in src/dta_gnn/io/runs.py
def resolve_current_run_dir(*, hint: str = "Build a dataset first.") -> Path:
    """Locate the current run folder, `runs/current`, relative to the CWD.

    Args:
        hint: Text appended to the error message when no current run exists.

    Returns:
        The resolved path of `runs/current`; if resolving raises, the
        unresolved relative path is returned and a warning is logged.

    Raises:
        FileNotFoundError: If `runs/current` does not exist.
    """
    current = Path("runs") / "current"
    if not current.exists():
        raise FileNotFoundError(
            f"No current run found. Looked for 'runs/current'. {hint}"
        )
    try:
        return current.resolve()
    except Exception as exc:
        logger.warning(
            "Failed to resolve current run directory, returning as-is: {}", exc
        )
        return current

Database downloader

download_chembl_db

download_chembl_db(version: str = LATEST_CHEMBL_VERSION, output_dir: str = '.') -> str

Download and extract ChEMBL SQLite database. Returns path to the extracted .db file.

Source code in src/dta_gnn/io/downloader.py
def download_chembl_db(
    version: str = LATEST_CHEMBL_VERSION, output_dir: str = "."
) -> str:
    """
    Download and extract ChEMBL SQLite database.
    Returns path to the extracted .db file.

    The archive ``chembl_<version>_sqlite.tar.gz`` is stored in `output_dir`
    (created if missing). If the archive is already present the download is
    skipped; extraction is always performed.

    Args:
        version: ChEMBL release identifier, substituted into the download URL.
        output_dir: Directory receiving both the archive and the extracted files.

    Returns:
        Path (as a string) of the first archive member whose name ends in
        ``.db``, joined onto `output_dir`.

    Raises:
        RuntimeError: If the download fails, or (on Python versions without
            tarfile extraction filters) an archive member would be written
            outside `output_dir`.
        FileNotFoundError: If the archive contains no ``.db`` member.
    """
    version = str(version)
    url = BASE_URL.format(version, version)
    filename = f"chembl_{version}_sqlite.tar.gz"
    extract_path = Path(output_dir)
    extract_path.mkdir(parents=True, exist_ok=True)
    output_path = extract_path / filename

    if not output_path.exists():
        logger.info(f"Downloading ChEMBL {version} from {url}...")
        # Stream into a side file and rename only once complete, so that an
        # interrupted download (including Ctrl-C) can never leave a truncated
        # archive that the next call would mistake for a finished one.
        partial_path = output_path.with_name(filename + ".part")
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                total_size = int(r.headers.get("content-length", 0))

                with open(partial_path, "wb") as f, tqdm(
                    desc=filename,
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for chunk in r.iter_content(chunk_size=8192):
                        bar.update(f.write(chunk))
            os.replace(partial_path, output_path)
        except Exception as e:
            raise RuntimeError(f"Download failed: {e}") from e
        finally:
            # No-op after a successful rename; removes the leftover otherwise.
            partial_path.unlink(missing_ok=True)
    else:
        logger.info("Archive already exists, skipping download.")

    logger.info("Extracting...")

    # The archive typically looks like chembl_33/chembl_33_sqlite/chembl_33.db,
    # so everything is extracted and the .db member is located afterwards.
    with tarfile.open(output_path, "r:gz") as tar:
        members = tar.getmembers()
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: refuse members that escape
            # the destination or carry special file types.
            tar.extractall(path=extract_path, filter="data")
        else:
            # Older interpreters: at least reject path traversal by name.
            # NOTE(review): this does not guard against malicious links.
            root = extract_path.resolve()
            for member in members:
                target = (root / member.name).resolve()
                if target != root and root not in target.parents:
                    raise RuntimeError(
                        f"Unsafe path in archive: {member.name!r}"
                    )
            tar.extractall(path=extract_path)

        for member in members:
            if member.name.endswith(".db"):
                logger.info(f"Found DB at {member.name}")
                return str(extract_path / member.name)

    raise FileNotFoundError("Could not find .db file in archive")

File / CSV utilities

CsvPreview dataclass

CsvPreview(df: DataFrame | None, error: str | None = None)

Result of CSV preview operation.

normalize_csv_path

normalize_csv_path(path: str | None) -> str | None

Normalize a CSV file path string.

Parameters:

Name Type Description Default
path str | None

Path string to normalize

required

Returns:

Type Description
str | None

Normalized path string, or None if input is empty/None

Source code in src/dta_gnn/io/utils.py
def normalize_csv_path(path: str | None) -> str | None:
    """Normalize a CSV file path string.

    Args:
        path: Path string to normalize

    Returns:
        Normalized path string, or None if input is empty/None
    """
    if not path:
        return None
    p = str(path).strip()
    return p or None

preview_csv

preview_csv(path: str | None, n: int = 50) -> pd.DataFrame | None

Preview a CSV file (wrapper that returns only DataFrame).

Parameters:

Name Type Description Default
path str | None

Path to CSV file

required
n int

Number of rows to read (default: 50)

50

Returns:

Type Description
DataFrame | None

DataFrame with first n rows, or None if error

Source code in src/dta_gnn/io/utils.py
def preview_csv(path: str | None, n: int = 50) -> pd.DataFrame | None:
    """Return only the DataFrame part of a CSV preview.

    Thin wrapper around `preview_csv_with_error` that discards the error
    message.

    Args:
        path: Path to CSV file
        n: Number of rows to read (default: 50)

    Returns:
        The `df` attribute of the preview result (None when no frame was read)
    """
    preview = preview_csv_with_error(path, n=n)
    return preview.df

preview_csv_with_error

preview_csv_with_error(path: str | None, n: int = 50) -> CsvPreview

Preview a CSV file with error handling.

Parameters:

Name Type Description Default
path str | None

Path to CSV file

required
n int

Number of rows to read (default: 50)

50

Returns:

Type Description
CsvPreview

CsvPreview object with DataFrame and optional error message

Source code in src/dta_gnn/io/utils.py
def preview_csv_with_error(path: str | None, n: int = 50) -> CsvPreview:
    """Read the first rows of a CSV file, capturing any failure as text.

    Args:
        path: Path to CSV file; normalized with `normalize_csv_path` first
        n: Number of rows to read (default: 50)

    Returns:
        CsvPreview with the DataFrame on success; with both fields None when
        the normalized path is empty; with `error` set to a
        "Could not read CSV: ..." message when reading raises
    """
    csv_path = normalize_csv_path(path)
    if not csv_path:
        # Nothing to preview is not treated as an error.
        return CsvPreview(df=None, error=None)

    try:
        frame = pd.read_csv(csv_path, nrows=n)
        outcome = CsvPreview(df=frame, error=None)
    except Exception as exc:
        outcome = CsvPreview(df=None, error=f"Could not read CSV: {exc}")
    return outcome

iter_existing_files

iter_existing_files(paths: Iterable[str | None]) -> list[str]

Filter a list of paths to only those that exist.

Parameters:

Name Type Description Default
paths Iterable[str | None]

Iterable of file paths (may include None)

required

Returns:

Type Description
list[str]

List of existing file paths

Source code in src/dta_gnn/io/utils.py
def iter_existing_files(paths: Iterable[str | None]) -> list[str]:
    """Filter a list of paths to only those that exist.

    Args:
        paths: Iterable of file paths (may include None)

    Returns:
        List of existing file paths
    """
    existing: list[str] = []
    for p in paths:
        if not p:
            continue
        try:
            if Path(p).exists():
                existing.append(p)
        except Exception as e:
            logger.debug("Could not check file existence for {!r}: {}", p, e)
            continue
    return existing

find_chembl_sqlite_dbs

find_chembl_sqlite_dbs() -> list[str]

Find available ChEMBL SQLite DB files under a chembl_dbs/ folder.

Searches both the current working directory and the repo root (when running from source). Returns absolute paths.

Returns:

Type Description
list[str]

Sorted list of absolute paths to SQLite database files

Source code in src/dta_gnn/io/utils.py
def find_chembl_sqlite_dbs() -> list[str]:
    """Discover ChEMBL SQLite database files inside `chembl_dbs/` folders.

    Two locations are searched recursively: `chembl_dbs/` under the current
    working directory and `chembl_dbs/` under the directory three levels
    above this file's folder (the repo root when running from source).
    Files ending in `.db`, `.sqlite` or `.sqlite3` (case-insensitive) match.

    Returns:
        Sorted, de-duplicated list of resolved paths as strings
    """
    search_roots: list[Path] = []

    local_dir = Path.cwd() / "chembl_dbs"
    if local_dir.is_dir():
        search_roots.append(local_dir)

    try:
        # This file lives at src/dta_gnn/io/utils.py, so parents[3] is the
        # repository root when running from a source checkout.
        source_root = Path(__file__).resolve().parents[3]
        repo_dbs = source_root / "chembl_dbs"
        if repo_dbs.is_dir():
            search_roots.append(repo_dbs)
    except Exception as exc:
        logger.debug("Could not determine repo root for chembl_dbs discovery: {}", exc)

    suffixes = {".db", ".sqlite", ".sqlite3"}
    unique_paths: set[str] = set()
    for root in search_roots:
        for entry in root.rglob("*"):
            if entry.is_file() and entry.suffix.lower() in suffixes:
                unique_paths.add(str(entry.resolve()))

    # Sorting gives a stable order for display.
    return sorted(unique_paths)

Cleaning (dta_gnn.cleaning)

standardize_activities

standardize_activities(df: DataFrame, convert_to_pchembl: bool = True, drop_censored: bool = False) -> pd.DataFrame

Standardize activity values:

- Filters out rows with a missing standard_value (or missing standard_units).
- Converts nanomolar values to molar when computing pChEMBL (ChEMBL standard_value is usually reported in nM).
- Calculates pChEMBL where it is missing, if requested.

Source code in src/dta_gnn/cleaning/functions.py
def standardize_activities(
    df: pd.DataFrame, convert_to_pchembl: bool = True, drop_censored: bool = False
) -> pd.DataFrame:
    """
    Standardize activity values.

    Works on a copy of `df`; the input frame is not modified.

    - Coerces `standard_value` (and `pchembl_value`, if present) to numeric;
      unparseable entries become NaN.
    - Drops rows with a missing `standard_value` or `standard_units`.
    - With `drop_censored`, keeps only rows whose `standard_relation` is
      "=" or "" (applies only when that column exists).
    - With `convert_to_pchembl`, fills missing `pchembl_value` entries as
      -log10(standard_value * 1e-9) for rows in nM with a positive value.
      Nothing is computed when the `pchembl_value` column is absent.

    Args:
        df: Activity table with `standard_value` and `standard_units` columns.
        convert_to_pchembl: Whether to fill missing pChEMBL values.
        drop_censored: Whether to discard censored (">", "<", ...) measurements.

    Returns:
        The cleaned copy of `df`.
    """
    df = df.copy()

    # Ensure numeric columns are actually numeric
    df["standard_value"] = pd.to_numeric(df["standard_value"], errors="coerce")
    has_pchembl = "pchembl_value" in df.columns
    if has_pchembl:
        df["pchembl_value"] = pd.to_numeric(df["pchembl_value"], errors="coerce")

    df = df.dropna(subset=["standard_value", "standard_units"])

    # Censored rows are otherwise kept: they may still be useful for binary
    # labelling, depending on the direction of the relation.
    if drop_censored and "standard_relation" in df.columns:
        df = df[df["standard_relation"].isin(["=", ""])]

    if convert_to_pchembl and has_pchembl:
        # Build the mask on the full frame and apply it positionally, so no
        # label alignment is involved and duplicate index labels are safe.
        fill_mask = (
            df["pchembl_value"].isna()
            & (df["standard_units"] == "nM")
            & (df["standard_value"] > 0)  # avoid log10 of zero/negative values
        ).to_numpy()
        if fill_mask.any():
            # pChEMBL = -log10(molar); nM -> M is a factor of 1e-9.
            molar = df.loc[fill_mask, "standard_value"].to_numpy() * 1e-9
            df.loc[fill_mask, "pchembl_value"] = -np.log10(molar)

    return df

aggregate_duplicates

aggregate_duplicates(df: DataFrame, group_cols: list = ['molecule_chembl_id', 'target_chembl_id'], agg_method: Literal['median', 'mean', 'max', 'min'] = 'median') -> pd.DataFrame

Deduplicate measurements for the same drug-target pair.

Source code in src/dta_gnn/cleaning/functions.py
def aggregate_duplicates(
    df: pd.DataFrame,
    group_cols: Optional[list] = None,
    agg_method: Literal["median", "mean", "max", "min"] = "median",
) -> pd.DataFrame:
    """
    Deduplicate measurements for the same drug-target pair.

    Rows are grouped by `group_cols` and the activity column is reduced with
    `agg_method`. The activity column is `pchembl_value` when present,
    otherwise `standard_value`; rows where it is NaN are dropped first.
    If present, `year` is reduced with "min" and `canonical_smiles` with
    "first". All other columns are absent from the result.

    Args:
        df: Activity table.
        group_cols: Columns identifying one pair. Defaults to
            ["molecule_chembl_id", "target_chembl_id"]; a single column name
            or any sequence of names is also accepted.
        agg_method: Reduction applied to the activity column.

    Returns:
        One row per group. An empty input frame is returned unchanged (the
        same object, not a copy).
    """
    # Resolve the default here rather than in the signature, to avoid a
    # mutable default shared across calls.
    if group_cols is None:
        group_cols = ["molecule_chembl_id", "target_chembl_id"]
    elif isinstance(group_cols, str):
        group_cols = [group_cols]
    else:
        # groupby treats a tuple as a single key, so normalize to a list.
        group_cols = list(group_cols)

    if df.empty:
        return df

    # We aggregate pchembl_value if available, else standard_value
    target_col = "pchembl_value" if "pchembl_value" in df.columns else "standard_value"

    # Drop rows where the target column is NaN before aggregating
    df_clean = df.dropna(subset=[target_col])

    agg_dict = {target_col: agg_method}
    if "year" in df_clean.columns:
        agg_dict["year"] = "min"
    if "canonical_smiles" in df_clean.columns:
        agg_dict["canonical_smiles"] = "first"

    return df_clean.groupby(group_cols).agg(agg_dict).reset_index()

canonicalize_smiles

canonicalize_smiles(smiles: str) -> Optional[str]

Canonicalize a single SMILES string using RDKit.

Source code in src/dta_gnn/cleaning/functions.py
def canonicalize_smiles(smiles: str) -> Optional[str]:
    """Return the RDKit canonical, non-isomeric form of one SMILES string.

    Returns None when RDKit yields no molecule for the input or when an
    exception is raised (the exception is logged as a warning).
    """
    canonical: Optional[str] = None
    try:
        parsed = Chem.MolFromSmiles(smiles)
        if parsed:
            canonical = Chem.MolToSmiles(parsed, canonical=True, isomericSmiles=False)
    except Exception as exc:
        logger.warning("Failed to canonicalize SMILES {!r}: {}", smiles, exc)
    return canonical

validate_split_sizes

validate_split_sizes(test_size: float, val_size: float) -> None

Validate test and validation split sizes.

Parameters:

Name Type Description Default
test_size float

Test set size (fraction)

required
val_size float

Validation set size (fraction)

required

Raises:

Type Description
ValueError

If sizes are invalid (not numbers, negative, or sum >= 1.0)

Source code in src/dta_gnn/cleaning/validation.py
def validate_split_sizes(test_size: float, val_size: float) -> None:
    """Validate test and validation split sizes.

    Both values are converted with `float()`, so numeric strings are accepted.

    Args:
        test_size: Test set size (fraction)
        val_size: Validation set size (fraction)

    Raises:
        ValueError: If sizes are invalid (not numbers, NaN, negative, or
            sum >= 1.0)
    """
    import math

    try:
        ts = float(test_size)
        vs = float(val_size)
    except (ValueError, TypeError, OverflowError) as exc:
        raise ValueError("Test/Validation sizes must be numbers.") from exc
    # NaN compares False against everything, so it would slip through both
    # range checks below; reject it explicitly.
    if math.isnan(ts) or math.isnan(vs):
        raise ValueError("Test/Validation sizes must be numbers.")
    if ts < 0 or vs < 0:
        raise ValueError("Test/Validation sizes must be non-negative.")
    if ts + vs >= 1.0:
        raise ValueError("Test size + validation size must be < 1.0.")

validate_sqlite_path

validate_sqlite_path(source: str, sqlite_path: str | None) -> None

Validate SQLite database path if source is 'sqlite'.

Parameters:

Name Type Description Default
source str

Data source type ('sqlite' or 'web')

required
sqlite_path str | None

Path to SQLite database file

required

Raises:

Type Description
ValueError

If source is 'sqlite' and path is missing or invalid

Source code in src/dta_gnn/cleaning/validation.py
def validate_sqlite_path(source: str, sqlite_path: str | None) -> None:
    """Validate SQLite database path if source is 'sqlite'.

    Args:
        source: Data source type ('sqlite' or 'web')
        sqlite_path: Path to SQLite database file

    Raises:
        ValueError: If source is 'sqlite' and path is missing or invalid
    """
    if (source or "").strip() != "sqlite":
        return
    p = normalize_csv_path(sqlite_path)
    if not p:
        raise ValueError("SQLite DB path is required when Data Source is 'sqlite'.")
    from pathlib import Path

    if not Path(p).exists():
        raise ValueError(f"SQLite DB not found: {p}")

Splitting (dta_gnn.splits)

split_random

split_random(df: DataFrame, test_size: float = 0.2, val_size: float = 0.1, seed: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]

Random split into Train/Val/Test.

Source code in src/dta_gnn/splits/strategies.py
def split_random(
    df: pd.DataFrame, test_size: float = 0.2, val_size: float = 0.1, seed: int = 42
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Random split into Train/Val/Test.

    Args:
        df: Rows to split; not modified.
        test_size: Fraction of all rows assigned to the test set.
        val_size: Fraction of all rows assigned to the validation set
            (0 disables the validation set).
        seed: Random state passed to both `train_test_split` calls.

    Returns:
        A 4-tuple ``(combined, train, val, test)``. `combined` is the
        concatenation of the non-empty parts with a fresh index; every
        non-empty part carries a `split` column ("train" / "val" / "test").
        For an empty input, four empty frames with the input's columns and
        dtypes are returned (without a `split` column).
    """

    def _empty_like_input() -> pd.DataFrame:
        # Empty frame with the same columns and dtypes as the input.
        return pd.DataFrame(columns=df.columns).astype(df.dtypes)

    if len(df) == 0:
        return (
            _empty_like_input(),
            _empty_like_input(),
            _empty_like_input(),
            _empty_like_input(),
        )

    train, test = train_test_split(df, test_size=test_size, random_state=seed)

    if val_size > 0:
        # val_size is relative to the TOTAL, so rescale it to a fraction of
        # the rows remaining after the test set was removed.
        relative_val_size = val_size / (1 - test_size)
        train, val = train_test_split(
            train, test_size=relative_val_size, random_state=seed
        )
        val = val.copy()
        val["split"] = "val"
    else:
        val = _empty_like_input()

    # Copy before labelling: the parts are selections of `df`, and writing a
    # column onto them directly triggers pandas' chained-assignment warning.
    train = train.copy()
    test = test.copy()
    train["split"] = "train"
    test["split"] = "test"

    # Filter out empty DataFrames before concat to avoid FutureWarning
    parts = [part for part in (train, val, test) if len(part) > 0]
    if parts:
        result = pd.concat(parts, ignore_index=True)
    else:
        result = pd.DataFrame(columns=df.columns)

    return result, train, val, test

split_cold_drug_scaffold

split_cold_drug_scaffold(df: DataFrame, smiles_col: str = 'smiles', test_size: float = 0.2, val_size: float = 0.1, seed: int = 42) -> pd.DataFrame

Scaffold split (Cold Drug).

Source code in src/dta_gnn/splits/strategies.py
def split_cold_drug_scaffold(
    df: pd.DataFrame,
    smiles_col: str = "smiles",
    test_size: float = 0.2,
    val_size: float = 0.1,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Scaffold split (Cold Drug).

    Rows are grouped by Murcko scaffold, the groups are shuffled
    deterministically with `seed`, and whole groups are assigned to train,
    then val, then test as the running row count crosses the cutoffs. A
    scaffold therefore never spans two splits; actual split sizes only
    approximate the requested fractions.

    Args:
        df: Rows to split; not modified (a labelled copy is returned).
        smiles_col: Column holding the SMILES strings.
        test_size: Target fraction of rows for the test split.
        val_size: Target fraction of rows for the validation split.
        seed: Seed for the scaffold-group shuffle.

    Returns:
        Copy of `df` with a `split` column. Rows whose scaffold could not be
        computed are not assigned (NaN, or their previous `split` value).

    Raises:
        KeyError: If `smiles_col` is not a column of `df`.
    """
    # Work on a copy, as `split_temporal` does, instead of mutating the
    # caller's frame.
    out = df.copy()

    # 1. Group row POSITIONS by scaffold. Positions rather than index labels,
    #    so duplicated labels cannot leak one row's split onto another.
    scaffolds: dict = {}
    for pos, smi in enumerate(out[smiles_col].tolist()):
        try:
            scaff = MurckoScaffold.MurckoScaffoldSmiles(smi)
            scaffolds.setdefault(scaff, []).append(pos)
        except Exception as e:
            logger.debug("Skipping SMILES {!r} during scaffold split: {}", smi, e)

    # 2. Deterministic shuffle of the scaffold groups.
    scaffold_sets = list(scaffolds.values())
    rng = np.random.default_rng(seed)
    rng.shuffle(scaffold_sets)

    # 3. Fill train first, then val, then test, one whole group at a time.
    train_cutoff = len(out) * (1 - test_size - val_size)
    val_cutoff = len(out) * (1 - test_size)
    positions = {"train": [], "val": [], "test": []}
    current_count = 0
    for group in scaffold_sets:
        if current_count < train_cutoff:
            bucket = "train"
        elif current_count < val_cutoff:
            bucket = "val"
        else:
            bucket = "test"
        positions[bucket].extend(group)
        current_count += len(group)

    # 4. Write the labels positionally, keeping any pre-existing values for
    #    rows that were not assigned.
    if "split" in out.columns:
        labels = out["split"].to_numpy(dtype=object, copy=True)
    else:
        labels = np.full(len(out), np.nan, dtype=object)
    for name, rows in positions.items():
        labels[np.asarray(rows, dtype=np.intp)] = name
    out["split"] = labels

    return out

split_temporal

split_temporal(df: DataFrame, year_col: str = 'year', split_year: int = 2022, val_size: float = 0.1) -> pd.DataFrame

Temporal split based on year.

- Train: year < split_year
- Test: year >= split_year
- Val: a random subset of Train (it could be time-based if requested, but a simple random sample of the past is standard)

Source code in src/dta_gnn/splits/strategies.py
def split_temporal(
    df: pd.DataFrame,
    year_col: str = "year",
    split_year: int = 2022,
    val_size: float = 0.1,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Temporal split based on year.
    Train: year < split_year
    Test: year >= split_year
    Val: random subset of Train (or could be time-based if requested, but simple random of past is standard)

    Args:
        df: Rows to split; not modified.
        year_col: Column holding the year; coerced to numeric.
        split_year: First year that belongs to the test split.
        val_size: Fraction of the train rows moved to validation
            (0 disables the validation split).
        seed: Random state for the train/validation sampling. Defaults to 42,
            the value that was previously hard-coded.

    Returns:
        Train, val and test rows concatenated with a fresh index and a
        `split` column. Rows whose year cannot be parsed match neither mask
        and are absent from the result. If the validation sampling raises
        ValueError, all pre-`split_year` rows stay in train.
    """
    df = df.copy()
    # Ensure year is numeric; unparseable values become NaN.
    df[year_col] = pd.to_numeric(df[year_col], errors="coerce")

    train_df = df[df[year_col] < split_year].copy()
    test_df = df[df[year_col] >= split_year].copy()

    # Empty frame with the input's dtypes (avoids a concat FutureWarning).
    val_df = pd.DataFrame(columns=df.columns).astype(df.dtypes)
    if val_size > 0 and len(train_df) > 0:
        try:
            train_part, val_part = train_test_split(
                train_df, test_size=val_size, random_state=seed
            )
        except ValueError:
            # Split not possible for this size: keep everything in train.
            pass
        else:
            train_df = train_part.copy()
            val_df = val_part.copy()
            val_df["split"] = "val"

    train_df["split"] = "train"
    test_df["split"] = "test"

    # Filter out empty DataFrames before concat to avoid FutureWarning
    parts = [part for part in (train_df, val_df, test_df) if len(part) > 0]
    if not parts:
        return pd.DataFrame(columns=df.columns)
    return pd.concat(parts, ignore_index=True)

Featurisation (dta_gnn.features)

Morgan fingerprints

calculate_morgan_fingerprints

calculate_morgan_fingerprints(df: DataFrame, smiles_col: str = 'smiles', radius: int = 2, n_bits: int = 2048, *, out_col: str = 'morgan_fingerprint', drop_failures: bool = True) -> pd.DataFrame

Calculate Morgan fingerprints for molecules in the DataFrame.

Returns a copy of df with an added fingerprint column containing bitstrings.

Parameters:

Name Type Description Default
df DataFrame

Input dataframe

required
smiles_col str

Column containing SMILES strings

'smiles'
radius int

Morgan radius (2 => ECFP4)

2
n_bits int

Fingerprint bit length

2048
out_col str

Output column name

'morgan_fingerprint'
drop_failures bool

Whether to drop rows that fail featurization

True
Source code in src/dta_gnn/features/__init__.py
def calculate_morgan_fingerprints(
    df: pd.DataFrame,
    smiles_col: str = "smiles",
    radius: int = 2,
    n_bits: int = 2048,
    *,
    out_col: str = "morgan_fingerprint",
    drop_failures: bool = True,
) -> pd.DataFrame:
    """Calculate Morgan fingerprints for molecules in the DataFrame.

    Returns a copy of `df` with an added fingerprint column containing bitstrings.
    A row fails when its SMILES is missing or blank, cannot be parsed by
    RDKit, or fingerprint generation raises; failed rows get None.

    Args:
        df: Input dataframe
        smiles_col: Column containing SMILES strings
        radius: Morgan radius (2 => ECFP4)
        n_bits: Fingerprint bit length
        out_col: Output column name
        drop_failures: Whether to drop rows that fail featurization

    Returns:
        Copy of `df` with `out_col` added; failed rows are removed when
        `drop_failures` is true.
    """

    from rdkit import Chem
    from rdkit.Chem import AllChem

    logger.info(f"Calculating Morgan fingerprints (r={radius}, n={n_bits})...")

    # Prefer the new generator API if available.
    mfgen = None
    try:
        mfgen = AllChem.GetMorganGenerator(radius=int(radius), fpSize=int(n_bits))
    except Exception:
        mfgen = None

    fps: list[str | None] = []
    for _, row in df.iterrows():
        smi = row.get(smiles_col)
        if smi is None or pd.isna(smi) or not str(smi).strip():
            fps.append(None)
            continue

        mol = Chem.MolFromSmiles(str(smi))
        if mol is None:
            fps.append(None)
            continue

        try:
            if mfgen is not None:
                fp = mfgen.GetFingerprint(mol)
            else:
                fp = AllChem.GetMorganFingerprintAsBitVect(
                    mol, int(radius), nBits=int(n_bits)
                )
            fps.append(fp.ToBitString())
        except Exception:
            # Best effort: record the failure and carry on with the next row.
            fps.append(None)

    out = df.copy()
    out[out_col] = fps

    # Failures are tracked by row position rather than index label, so
    # string/tuple labels and duplicated labels are handled correctly.
    succeeded = [fp is not None for fp in fps]
    n_failed = len(succeeded) - sum(succeeded)
    if drop_failures and n_failed:
        logger.warning(f"Failed to featurize {n_failed} molecules.")
        out = out.loc[succeeded]

    return out

2-D molecular graphs

MoleculeGraph2D dataclass

MoleculeGraph2D(molecule_chembl_id: str, atom_type: ndarray, atom_feat: ndarray, edge_index: ndarray, edge_attr: ndarray)

smiles_to_graph_2d

smiles_to_graph_2d(*, molecule_chembl_id: str, smiles: str) -> MoleculeGraph2D | None

Convert a SMILES string into a simple 2D molecular graph.

Node features are fixed-size (6) numeric features. Edge features are fixed-size (6) numeric features.

Source code in src/dta_gnn/features/molecule_graphs.py
def smiles_to_graph_2d(
    *, molecule_chembl_id: str, smiles: str
) -> MoleculeGraph2D | None:
    """Build a simple 2D molecular graph from a SMILES string.

    Each atom gets 6 float32 node features: atomic number, total degree,
    formal charge, total hydrogen count, aromaticity flag and mass. The
    atomic number is also stored separately as int64 in `atom_type`.

    Each bond becomes two directed edges (begin->end and end->begin) that
    share 6 float32 features: single / double / triple bond-type flags,
    aromatic, conjugated and in-ring flags.

    Returns None when RDKit cannot parse the SMILES or the molecule has no
    atoms.
    """

    from rdkit import Chem

    mol = Chem.MolFromSmiles(str(smiles))
    num_atoms = mol.GetNumAtoms() if mol is not None else 0
    if num_atoms == 0:
        return None

    atom_type = np.zeros((num_atoms,), dtype=np.int64)
    atom_feat = np.zeros((num_atoms, 6), dtype=np.float32)

    for row, atom in enumerate(mol.GetAtoms()):
        z = int(atom.GetAtomicNum())
        atom_type[row] = np.int64(z)
        # One row of 6 node features, in a fixed column order.
        atom_feat[row] = (
            float(z),
            float(atom.GetTotalDegree()),
            float(atom.GetFormalCharge()),
            float(atom.GetTotalNumHs(includeNeighbors=True)),
            1.0 if atom.GetIsAromatic() else 0.0,
            float(atom.GetMass()),
        )

    def _flag(condition: bool) -> float:
        return 1.0 if condition else 0.0

    sources: list[int] = []
    targets: list[int] = []
    edge_rows: list[list[float]] = []

    for bond in mol.GetBonds():
        begin = int(bond.GetBeginAtomIdx())
        end = int(bond.GetEndAtomIdx())
        kind = bond.GetBondType()
        features = [
            _flag(kind == Chem.BondType.SINGLE),
            _flag(kind == Chem.BondType.DOUBLE),
            _flag(kind == Chem.BondType.TRIPLE),
            _flag(bond.GetIsAromatic()),
            _flag(bond.GetIsConjugated()),
            _flag(bond.IsInRing()),
        ]

        # Emit both directions, forward edge first, with identical features.
        sources.extend((begin, end))
        targets.extend((end, begin))
        edge_rows.extend((features, features))

    if sources:
        edge_index = np.asarray([sources, targets], dtype=np.int64)
        edge_attr = np.asarray(edge_rows, dtype=np.float32)
    else:
        # No bonds: keep well-formed empty arrays of the expected widths.
        edge_index = np.zeros((2, 0), dtype=np.int64)
        edge_attr = np.zeros((0, 6), dtype=np.float32)

    return MoleculeGraph2D(
        molecule_chembl_id=str(molecule_chembl_id),
        atom_type=atom_type,
        atom_feat=atom_feat,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )

build_graphs_2d

build_graphs_2d(*, molecules: Iterable[tuple[str, str]], drop_failures: bool = True) -> list[MoleculeGraph2D]
Source code in src/dta_gnn/features/molecule_graphs.py
def build_graphs_2d(
    *,
    molecules: Iterable[tuple[str, str]],
    drop_failures: bool = True,
) -> list[MoleculeGraph2D]:
    """Convert (molecule id, SMILES) pairs into 2D molecular graphs.

    Args:
        molecules: Iterable of (molecule id, SMILES) pairs; both items are
            passed through `str()` before conversion.
        drop_failures: If true, pairs that `smiles_to_graph_2d` cannot
            convert are skipped; otherwise the first such pair raises.

    Returns:
        The graphs that were built, in input order.

    Raises:
        ValueError: If a pair cannot be converted and `drop_failures` is false.
    """
    built: list[MoleculeGraph2D] = []
    for mol_id, smiles in molecules:
        graph = smiles_to_graph_2d(molecule_chembl_id=str(mol_id), smiles=str(smiles))
        if graph is not None:
            built.append(graph)
        elif not drop_failures:
            raise ValueError(
                f"Failed to parse SMILES for molecule {mol_id!r}: {smiles!r}"
            )
    return built

Models (dta_gnn.models)

Random Forest baseline

train_random_forest_on_run

train_random_forest_on_run(run_dir: str | Path, *, n_estimators: int = 100, max_depth: int | None = 5, random_seed: int = 42) -> RandomForestTrainResult

Train a RandomForest baseline on runs/<run>/dataset.csv + compounds.csv.

Writes:

- model_rf.pkl
- model_metrics.json
- model_predictions.csv

Source code in src/dta_gnn/models/random_forest.py
def train_random_forest_on_run(
    run_dir: str | Path,
    *,
    n_estimators: int = 100,
    max_depth: int | None = 5,
    random_seed: int = 42,
) -> RandomForestTrainResult:
    """Train a RandomForest baseline on `runs/<run>/dataset.csv` + `compounds.csv`.

    SMILES are taken from ``dataset.csv`` when it already has a ``smiles`` or
    ``canonical_smiles`` column; otherwise they are joined in from
    ``compounds.csv`` on ``molecule_chembl_id``. Molecules are featurized with
    ``_morgan_fingerprints`` (radius 2, 2048 bits). Depending on the task type
    reported by ``_infer_task_type_from_labels``, either a
    ``RandomForestClassifier`` (scored by accuracy) or a
    ``RandomForestRegressor`` (scored by RMSE, MAE, R², Pearson and Spearman
    correlation) is fitted on the rows whose ``split`` is ``"train"``.

    Writes:
    - model_rf.pkl
    - model_metrics.json
    - model_predictions.csv (rows from the ``val`` and ``test`` splits only)

    Args:
        run_dir: Run folder containing ``dataset.csv`` and ``compounds.csv``.
        n_estimators: Number of trees in the forest.
        max_depth: Maximum tree depth; ``None`` is passed through to sklearn.
        random_seed: Seed passed as ``random_state`` to the forest.

    Returns:
        A ``RandomForestTrainResult`` holding the artifact paths and the
        metrics dictionary that was written to ``model_metrics.json``.

    Raises:
        FileNotFoundError: If ``dataset.csv`` or ``compounds.csv`` is missing.
        ValueError: If required columns are missing, no SMILES column can be
            located, no usable rows remain, or fewer than two training rows
            survive featurization.
    """

    import joblib
    from scipy.stats import pearsonr, spearmanr
    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
    from sklearn.metrics import (
        accuracy_score,
        mean_absolute_error,
        mean_squared_error,
        r2_score,
    )

    run_dir = Path(run_dir).resolve()
    dataset_path = run_dir / "dataset.csv"
    compounds_path = run_dir / "compounds.csv"
    if not dataset_path.exists():
        raise FileNotFoundError(f"Missing dataset.csv in run folder: {dataset_path}")
    if not compounds_path.exists():
        raise FileNotFoundError(
            f"Missing compounds.csv in run folder: {compounds_path}"
        )

    df = pd.read_csv(dataset_path)
    compounds = pd.read_csv(compounds_path)

    required = {"molecule_chembl_id", "label"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"dataset.csv missing columns: {sorted(missing)}")

    if "split" not in df.columns:
        raise ValueError(
            "dataset.csv missing 'split' column (expected train/val/test)."
        )

    # If dataset already contains SMILES, don't force a merge that can create
    # suffixes (smiles_x/smiles_y). Otherwise join from compounds.csv.
    smiles_candidates = [c for c in ["smiles", "canonical_smiles"] if c in df.columns]

    if not smiles_candidates:
        if "molecule_chembl_id" not in compounds.columns:
            raise ValueError("compounds.csv must contain 'molecule_chembl_id'.")
        comp_smiles_col = next(
            (c for c in ("smiles", "canonical_smiles") if c in compounds.columns),
            None,
        )
        if comp_smiles_col is None:
            raise ValueError(
                "compounds.csv must contain a SMILES column ('smiles' or 'canonical_smiles')."
            )

        # One SMILES per molecule: a left merge against duplicated compound ids
        # would otherwise multiply dataset rows and skew training and metrics.
        # Rows without a usable SMILES are dropped first so that a later
        # duplicate carrying a SMILES can still be picked up.
        smiles_lookup = (
            compounds[["molecule_chembl_id", comp_smiles_col]]
            .rename(columns={comp_smiles_col: "smiles"})
            .dropna(subset=["molecule_chembl_id", "smiles"])
            .drop_duplicates(subset=["molecule_chembl_id"], keep="first")
        )
        df = df.merge(smiles_lookup, on="molecule_chembl_id", how="left")
    else:
        # Normalize to a single column name.
        if "smiles" not in df.columns and "canonical_smiles" in df.columns:
            df = df.rename(columns={"canonical_smiles": "smiles"})

    # If we still ended up with suffixes (e.g., user-provided dataset.csv includes
    # smiles and we merged anyway in a previous run), coalesce them.
    if "smiles" not in df.columns:
        if "smiles_x" in df.columns and "smiles_y" in df.columns:
            df["smiles"] = df["smiles_x"].combine_first(df["smiles_y"])
        elif "smiles_x" in df.columns:
            df["smiles"] = df["smiles_x"]
        elif "smiles_y" in df.columns:
            df["smiles"] = df["smiles_y"]

    if "smiles" not in df.columns:
        raise ValueError(
            "Could not locate a SMILES column after merging dataset.csv and compounds.csv."
        )

    df = df.dropna(subset=["smiles", "label", "split"]).copy()
    if df.empty:
        raise ValueError(
            "No rows left after joining SMILES and dropping missing values."
        )

    task_type = _infer_task_type_from_labels(df)
    y_all = df["label"].astype(float).to_numpy()
    # NOTE(review): valid_mask is presumably a boolean mask aligned with the
    # SMILES list, and X_all holds only the rows where it is true -- confirm
    # against _morgan_fingerprints.
    X_all, valid_mask = _morgan_fingerprints(
        df["smiles"].astype(str).tolist(), radius=2, n_bits=2048
    )
    if X_all.shape[0] == 0:
        raise ValueError("No valid SMILES to build fingerprints.")

    df_valid = df.loc[valid_mask].reset_index(drop=True)
    y_valid = y_all[valid_mask]
    X_valid = X_all

    split_labels = df_valid["split"].astype(str)
    train_mask = split_labels == "train"
    val_mask = split_labels == "val"
    test_mask = split_labels == "test"
    if int(train_mask.sum()) < 2:
        raise ValueError("Not enough training rows after featurization.")

    X_train, y_train = X_valid[train_mask.values], y_valid[train_mask.values]

    if task_type == "classification":
        model = RandomForestClassifier(
            n_estimators=int(n_estimators),
            max_depth=max_depth,
            random_state=int(random_seed),
            n_jobs=-1,
        )
        y_train = y_train.astype(int)
    else:
        model = RandomForestRegressor(
            n_estimators=int(n_estimators),
            max_depth=max_depth,
            random_state=int(random_seed),
            n_jobs=-1,
        )

    model.fit(X_train, y_train)

    def _finite_or_none(value) -> float | None:
        # Correlations are NaN when the predictions are constant; json.dumps
        # would then emit the non-standard token NaN into the metrics file.
        number = float(value)
        return number if np.isfinite(number) else None

    def _eval(mask: np.ndarray) -> dict[str, float | None] | None:
        # Score one split; returns None when the split has no rows.
        if int(mask.sum()) == 0:
            return None
        Xs, ys = X_valid[mask], y_valid[mask]
        yhat = model.predict(Xs)

        if task_type == "classification":
            return {"accuracy": float(accuracy_score(ys.astype(int), yhat))}

        # R² and the correlations are undefined for constant labels.
        labels_vary = len(np.unique(ys)) > 1
        return {
            "rmse": float(np.sqrt(mean_squared_error(ys, yhat))),
            "mae": float(mean_absolute_error(ys, yhat)),
            "r2": float(r2_score(ys, yhat)) if labels_vary else 0.0,
            "pearson_r": (
                _finite_or_none(pearsonr(ys, yhat)[0]) if labels_vary else None
            ),
            "spearman_r": (
                _finite_or_none(spearmanr(ys, yhat)[0]) if labels_vary else None
            ),
        }

    splits: dict[str, dict[str, float | None]] = {}
    for name, mask in (
        ("train", train_mask.values),
        ("val", val_mask.values),
        ("test", test_mask.values),
    ):
        m = _eval(mask)
        if m is not None:
            splits[name] = m

    metrics = {
        "model_type": "RandomForest",
        "task_type": task_type,
        # Hyperparameters are recorded so the metrics file is self-describing,
        # mirroring what the SVR baseline writes.
        "params": {
            "n_estimators": int(n_estimators),
            "max_depth": None if max_depth is None else int(max_depth),
            "random_seed": int(random_seed),
        },
        "splits": splits,
    }

    # Save artifacts
    model_path = run_dir / "model_rf.pkl"
    metrics_path = run_dir / "model_metrics.json"
    predictions_path = run_dir / "model_predictions.csv"

    joblib.dump(model, model_path)
    metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    # Predictions file (val/test rows; train is usually too large for UI preview)
    holdout_mask = val_mask | test_mask
    pred_rows = df_valid.loc[holdout_mask].copy()
    if not pred_rows.empty:
        pred_rows["y_pred"] = model.predict(X_valid[holdout_mask.values])
    pred_rows.to_csv(predictions_path, index=False)

    return RandomForestTrainResult(
        run_dir=run_dir,
        task_type=task_type,
        model_path=model_path,
        metrics_path=metrics_path,
        predictions_path=predictions_path,
        metrics=metrics,
    )

SVR baseline

train_svr_on_run

train_svr_on_run(run_dir: str | Path, *, C: float = 1.0, epsilon: float = 0.5, kernel: str = 'rbf', random_seed: int = 42) -> SvrTrainResult

Train an SVR baseline on Morgan fingerprints.

Writes the following files to the run directory: model_svr.pkl, model_metrics_svr.json, and model_predictions_svr.csv.

Source code in src/dta_gnn/models/svr.py
def train_svr_on_run(
    run_dir: str | Path,
    *,
    C: float = 1.0,
    epsilon: float = 0.5,
    kernel: str = "rbf",
    random_seed: int = 42,
) -> SvrTrainResult:
    """Train an SVR baseline on Morgan fingerprints.

    The dataset is loaded through ``_load_dataset_with_smiles``; this function
    reads its ``smiles``, ``label`` and ``split`` columns. Molecules are
    featurized with ``_morgan_fingerprints`` (radius 2, 2048 bits) and the
    model is fitted on the rows whose ``split`` is ``"train"``.

    Writes:
    - model_svr.pkl
    - model_metrics_svr.json
    - model_predictions_svr.csv (rows from the ``val`` and ``test`` splits only)

    Args:
        run_dir: Run folder to read the dataset from and write artifacts to.
        C: Regularization parameter passed to ``SVR``.
        epsilon: Epsilon-tube width passed to ``SVR``.
        kernel: ``"rbf"`` or ``"linear"`` (case-insensitive); any other value
            silently falls back to ``"rbf"``.
        random_seed: Only recorded in the metrics ``params``; it is not passed
            to ``SVR``.

    Returns:
        An ``SvrTrainResult`` holding the artifact paths and the metrics
        dictionary that was written to ``model_metrics_svr.json``.

    Raises:
        ValueError: If no SMILES can be featurized or fewer than two training
            rows survive featurization.
    """

    import joblib
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
    from sklearn.svm import SVR

    run_dir = Path(run_dir).resolve()
    df = _load_dataset_with_smiles(run_dir)

    y_all = df["label"].astype(float).to_numpy()
    # NOTE(review): valid_mask is presumably a boolean mask aligned with the
    # SMILES list, and X_all holds only the rows where it is true -- confirm
    # against _morgan_fingerprints.
    X_all, valid_mask = _morgan_fingerprints(
        df["smiles"].astype(str).tolist(), radius=2, n_bits=2048
    )
    if X_all.shape[0] == 0:
        raise ValueError("No valid SMILES to build fingerprints.")

    df_valid = df.loc[valid_mask].reset_index(drop=True)
    y_valid = y_all[valid_mask]
    X_valid = X_all

    split_labels = df_valid["split"].astype(str)
    train_mask = split_labels == "train"
    val_mask = split_labels == "val"
    test_mask = split_labels == "test"
    if int(train_mask.sum()) < 2:
        raise ValueError("Not enough training rows after featurization.")

    X_train, y_train = X_valid[train_mask.values], y_valid[train_mask.values]

    # Kernel options are limited in the UI; validate anyway.
    k = str(kernel).strip().lower()
    if k not in {"rbf", "linear"}:
        k = "rbf"

    model = SVR(kernel=k, C=float(C), epsilon=float(epsilon))
    model.fit(X_train, y_train)

    def _finite_or_none(value) -> float | None:
        # Correlations are NaN when the predictions are constant (which an SVR
        # can produce when every residual fits inside the epsilon tube);
        # json.dumps would then emit the non-standard token NaN.
        number = float(value)
        return number if np.isfinite(number) else None

    def _eval(mask: np.ndarray) -> dict[str, float | None] | None:
        # Score one split; returns None when the split has no rows.
        if int(mask.sum()) == 0:
            return None
        Xs, ys = X_valid[mask], y_valid[mask]
        yhat = model.predict(Xs)

        # R² and the correlations are undefined for constant labels.
        labels_vary = len(np.unique(ys)) > 1
        return {
            "rmse": float(np.sqrt(mean_squared_error(ys, yhat))),
            "mae": float(mean_absolute_error(ys, yhat)),
            "r2": float(r2_score(ys, yhat)) if labels_vary else 0.0,
            "pearson_r": (
                _finite_or_none(pearsonr(ys, yhat)[0]) if labels_vary else None
            ),
            "spearman_r": (
                _finite_or_none(spearmanr(ys, yhat)[0]) if labels_vary else None
            ),
        }

    splits: dict[str, dict[str, float | None]] = {}
    for name, mask in (
        ("train", train_mask.values),
        ("val", val_mask.values),
        ("test", test_mask.values),
    ):
        m = _eval(mask)
        if m is not None:
            splits[name] = m

    metrics = {
        "model_type": "SVR",
        "task_type": "regression",
        "params": {
            "C": float(C),
            "epsilon": float(epsilon),
            "kernel": k,
            "random_seed": int(random_seed),
        },
        "splits": splits,
    }

    model_path = run_dir / "model_svr.pkl"
    metrics_path = run_dir / "model_metrics_svr.json"
    predictions_path = run_dir / "model_predictions_svr.csv"

    joblib.dump(model, model_path)
    metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    # Predictions are written for the held-out (val/test) rows only.
    holdout_mask = val_mask | test_mask
    pred_rows = df_valid.loc[holdout_mask].copy()
    if not pred_rows.empty:
        pred_rows["y_pred"] = model.predict(X_valid[holdout_mask.values])
    pred_rows.to_csv(predictions_path, index=False)

    return SvrTrainResult(
        run_dir=run_dir,
        task_type="regression",
        model_path=model_path,
        metrics_path=metrics_path,
        predictions_path=predictions_path,
        metrics=metrics,
    )

Graph Neural Networks

GnnTrainConfig dataclass

GnnTrainConfig(architecture: Literal['gin', 'gcn', 'gat', 'sage', 'pna', 'transformer', 'tag', 'arma', 'cheb', 'supergat'] = 'gin', embedding_dim: int = 128, hidden_dim: int = 128, num_layers: int = 5, dropout: float = 0.1, pooling: Literal['add', 'mean', 'max', 'attention'] = 'add', residual: bool = False, head_mlp_layers: int = 2, gin_conv_mlp_layers: int = 2, gin_train_eps: bool = False, gin_eps: float = 0.0, gat_heads: int = 4, sage_aggr: str = 'mean', transformer_heads: int = 4, transformer_edge_dim: int | None = None, tag_k: int = 2, arma_num_stacks: int = 1, arma_num_layers: int = 1, cheb_k: int = 2, supergat_heads: int = 4, supergat_attention_type: str = 'MX', lr: float = 0.001, weight_decay: float = 0.0, batch_size: int = 64, epochs: int = 10, random_seed: int = 42, device: str | None = None)

GnnTrainResult dataclass

GnnTrainResult(run_dir: Path, task_type: Literal['regression'], model_path: Path, encoder_path: Path, encoder_config_path: Path, metrics_path: Path, predictions_path: Path, metrics: dict[str, Any], best_epoch: int | None = None)

train_gnn_on_run

train_gnn_on_run(run_dir: str | Path, *, config: GnnTrainConfig | None = None, wandb_run=None) -> GnnTrainResult

Train a 2D GNN model on the run's dataset.csv, using molecule graphs from compounds.csv.

Writes:
  • model_gnn_<architecture>.pt
  • encoder_<architecture>.pt
  • encoder_<architecture>_config.json
  • model_metrics_gnn_<architecture>.json
  • model_predictions_gnn_<architecture>.csv
Source code in src/dta_gnn/models/gnn.py
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
def train_gnn_on_run(
    run_dir: str | Path,
    *,
    config: GnnTrainConfig | None = None,
    wandb_run=None,
) -> GnnTrainResult:
    """Train a 2D GNN model on the run's dataset.csv, using molecule graphs from compounds.csv.

    Writes:
      - model_gnn_<architecture>.pt
      - encoder_<architecture>.pt
      - encoder_<architecture>_config.json
      - model_metrics_gnn_<architecture>.json
      - model_predictions_gnn_<architecture>.csv
    """

    _require_pyg()
    import torch
    from torch_geometric.loader import DataLoader

    cfg = config or GnnTrainConfig()
    set_seed(cfg.random_seed)

    run_path = _resolve_run_dir(run_dir)
    dataset_path = run_path / "dataset.csv"
    compounds_path = run_path / "compounds.csv"

    if not dataset_path.exists():
        raise ValueError(f"Missing dataset.csv in run folder: {dataset_path}")
    if not compounds_path.exists():
        raise ValueError(f"Missing compounds.csv in run folder: {compounds_path}")

    df = pd.read_csv(dataset_path)
    required = {"molecule_chembl_id", "label", "split"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"dataset.csv missing columns: {sorted(missing)}")

    df_comp = pd.read_csv(compounds_path)
    if "molecule_chembl_id" not in df_comp.columns or "smiles" not in df_comp.columns:
        raise ValueError(
            "compounds.csv must contain 'molecule_chembl_id' and 'smiles'."
        )

    meta = _load_json(run_path / "metadata.json")
    task_type = _infer_task_type(df, meta)

    # Map molecule id -> smiles
    df_comp_clean = (
        df_comp.dropna(subset=["molecule_chembl_id", "smiles"])
        .drop_duplicates(subset=["molecule_chembl_id"], keep="first")
    )
    smiles_map = (
        df_comp_clean.set_index("molecule_chembl_id")["smiles"]
        .astype(str)
        .to_dict()
    )

    df = df.copy()
    df["_mid"] = df["molecule_chembl_id"].astype(str)
    df["_smiles"] = df["_mid"].map(smiles_map)
    df = df.dropna(subset=["_smiles"]).reset_index(drop=True)
    if df.empty:
        raise ValueError("No rows left after joining SMILES from compounds.csv")

    # Build one graph per dataset row (simple MVP, matches current RF 'molecule-only' baseline).
    # Keep df aligned by building graphs row-by-row.
    from torch_geometric.data import Data

    data_list: list[Data] = []

    for _, row in df.iterrows():
        mid = str(row["_mid"])
        smi = str(row["_smiles"])
        g = smiles_to_graph_2d(molecule_chembl_id=mid, smiles=smi)
        if g is None:
            continue
        data_list.append(
            Data(
                atom_type=torch.from_numpy(g.atom_type.astype(np.int64)),
                atom_feat=torch.from_numpy(g.atom_feat.astype(np.float32)),
                edge_index=torch.from_numpy(g.edge_index.astype(np.int64)),
                edge_attr=torch.from_numpy(g.edge_attr.astype(np.float32)),
                y=torch.tensor(float(row["label"]), dtype=torch.float32),
                split=str(row["split"]),
                molecule_chembl_id=str(row.get("molecule_chembl_id", "")),
                target_chembl_id=str(row.get("target_chembl_id", "")),
                num_nodes=int(g.atom_feat.shape[0]),
            )
        )

    if not data_list:
        raise ValueError("No valid SMILES could be converted to graphs.")

    train_data = [d for d in data_list if getattr(d, "split") == "train"]
    val_data = [d for d in data_list if getattr(d, "split") == "val"]
    test_data = [d for d in data_list if getattr(d, "split") == "test"]

    if not train_data:
        raise ValueError("No training rows found (split == 'train').")

    edge_dim = (
        int(train_data[0].edge_attr.shape[1]) if train_data[0].edge_attr.numel() else 6
    )

    pna_deg = None
    pna_deg_list: list[int] | None = None
    if str(cfg.architecture).strip().lower() == "pna":
        from torch_geometric.utils import degree

        # Build a degree histogram across the training graphs.
        # PNAConv expects a histogram tensor where index i holds the count of nodes with degree i.
        hist: torch.Tensor | None = None
        for d in train_data:
            deg_vec = degree(
                d.edge_index[1], num_nodes=int(d.num_nodes), dtype=torch.long
            )
            deg_hist = torch.bincount(deg_vec, minlength=int(deg_vec.max().item()) + 1)
            hist = (
                deg_hist
                if hist is None
                else (
                    torch.nn.functional.pad(
                        hist, (0, max(0, deg_hist.numel() - hist.numel()))
                    )
                    + torch.nn.functional.pad(
                        deg_hist, (0, max(0, hist.numel() - deg_hist.numel()))
                    )
                )
            )

        pna_deg = (
            hist if hist is not None else torch.tensor([0], dtype=torch.long)
        ).cpu()
        pna_deg_list = [int(x) for x in pna_deg.tolist()]

    GnnEncoder, GnnPredictor = _make_encoder_and_model(
        embedding_dim=int(cfg.embedding_dim),
        hidden_dim=int(cfg.hidden_dim),
        num_layers=int(cfg.num_layers),
        dropout=float(cfg.dropout),
        edge_dim=edge_dim,
        architecture=str(cfg.architecture),
        pooling=str(cfg.pooling),
        residual=bool(cfg.residual),
        head_mlp_layers=int(cfg.head_mlp_layers),
        gin_conv_mlp_layers=int(cfg.gin_conv_mlp_layers),
        gin_train_eps=bool(cfg.gin_train_eps),
        gin_eps=float(cfg.gin_eps),
        gat_heads=int(cfg.gat_heads),
        sage_aggr=str(cfg.sage_aggr),
        pna_deg=pna_deg,
        transformer_heads=int(cfg.transformer_heads),
        transformer_edge_dim=cfg.transformer_edge_dim,
        tag_k=int(cfg.tag_k),
        arma_num_stacks=int(cfg.arma_num_stacks),
        arma_num_layers=int(cfg.arma_num_layers),
        cheb_k=int(cfg.cheb_k),
        supergat_heads=int(cfg.supergat_heads),
        supergat_attention_type=str(cfg.supergat_attention_type),
    )

    model = GnnPredictor()
    device = _get_device(cfg.device)

    # TransformerConv doesn't support MPS (scatter_reduce not implemented)
    # Fall back to CPU if transformer architecture is used on MPS
    if str(cfg.architecture).strip().lower() == "transformer" and str(device) == "mps":
        import torch
        import warnings
        import sys
        import os

        # Only show warning if not running in pytest (to reduce test noise)
        # Check both sys.modules and environment variable
        is_pytest = (
            "pytest" in sys.modules
            or os.environ.get("PYTEST_CURRENT_TEST") is not None
            or any("pytest" in arg for arg in sys.argv)
        )

        if not is_pytest:
            warnings.warn(
                "TransformerConv doesn't support MPS. Falling back to CPU. "
                "For better performance, use device='cpu' explicitly or use a different architecture.",
                UserWarning,
            )
        device = torch.device("cpu")

    model.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=float(cfg.lr),
        weight_decay=float(cfg.weight_decay),
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10, min_lr=1e-6,
    )

    criterion = torch.nn.MSELoss()

    train_loader = DataLoader(train_data, batch_size=int(cfg.batch_size), shuffle=True)
    val_loader = (
        DataLoader(val_data, batch_size=int(cfg.batch_size), shuffle=False)
        if val_data
        else None
    )
    test_loader = (
        DataLoader(test_data, batch_size=int(cfg.batch_size), shuffle=False)
        if test_data
        else None
    )

    torch.manual_seed(int(cfg.random_seed))

    def _compute_val_metrics(model, val_loader, device, criterion):
        """Compute validation metrics during training."""
        model.eval()
        val_losses = []
        y_true_list = []
        y_pred_list = []

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                logits, _ = model(
                    batch.atom_type,
                    batch.atom_feat,
                    batch.edge_index,
                    batch.edge_attr,
                    batch.batch,
                )
                y = batch.y.view(-1).float()
                loss = criterion(logits.view(-1), y)
                val_losses.append(loss.item())

                y_true_list.extend(y.cpu().numpy().tolist())
                y_pred_list.extend(logits.view(-1).cpu().numpy().tolist())

        model.train()

        if not y_true_list:
            return None

        yt = np.asarray(y_true_list, dtype=float)
        yp = np.asarray(y_pred_list, dtype=float)

        if np.any(np.isnan(yp)) or np.any(np.isinf(yp)):
            return {
                "loss": float("inf"),
                "rmse": float("inf"),
                "mae": float("inf"),
                "r2": -float("inf"),
                "pearson_r": 0.0,
                "spearman_r": 0.0,
            }

        from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
        from scipy.stats import pearsonr, spearmanr

        rmse = float(np.sqrt(mean_squared_error(yt, yp)))
        mae = float(mean_absolute_error(yt, yp))
        r2 = float(r2_score(yt, yp)) if yt.size >= 2 else None
        pearson_r = float(pearsonr(yt, yp)[0]) if yt.size >= 2 else None
        spearman_r = float(spearmanr(yt, yp)[0]) if yt.size >= 2 else None

        return {
            "loss": sum(val_losses) / len(val_losses) if val_losses else 0.0,
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "pearson_r": pearson_r,
            "spearman_r": spearman_r,
        }

    # Initialize best model tracking
    best_val_score = float('-inf')
    best_epoch = -1
    best_model_state = None
    best_encoder_state = None
    best_val_metrics = None  # Store best validation metrics for logging

    logger.info(f"Training GNN for {cfg.epochs} epochs (batch_size={cfg.batch_size}, lr={cfg.lr:.6f})...")
    for epoch in range(int(cfg.epochs)):
        model.train()
        epoch_losses = []
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad(set_to_none=True)
            logits, _z = model(
                batch.atom_type,
                batch.atom_feat,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
            )
            y = batch.y.view(-1).float()
            loss = criterion(logits.view(-1), y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_losses.append(loss.item())

        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0.0

        # Evaluate on validation set if available
        val_loss = None
        val_metrics = None
        if val_loader is not None:
            val_metrics = _compute_val_metrics(model, val_loader, device, criterion)
            if val_metrics:
                val_loss = val_metrics["loss"]

                # Use R² if available, otherwise use -RMSE (so higher is better)
                current_score = val_metrics["r2"] if val_metrics["r2"] is not None else -val_metrics["rmse"]

                if current_score > best_val_score:
                    best_val_score = current_score
                    best_epoch = epoch + 1
                    # Deep-copy state so later epochs don't overwrite it (state_dict().copy() is shallow)
                    best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                    best_encoder_state = {k: v.cpu().clone() for k, v in model.encoder.state_dict().items()}
                    best_val_metrics = val_metrics.copy()
                    score_str = f"R²={val_metrics['r2']:.4f}" if val_metrics["r2"] is not None else f"RMSE={val_metrics['rmse']:.4f}"
                    logger.info(f"New best model at epoch {best_epoch} (val_{score_str})")

        if val_loss is not None:
            scheduler.step(val_loss)

        # Log to wandb if available
        if wandb_run is not None:
            log_dict = {
                "train/loss": avg_loss,
                "epoch": epoch + 1,
            }
            if val_loss is not None:
                log_dict["val/loss"] = val_loss
            if val_metrics is not None:
                log_dict["val/rmse"] = val_metrics["rmse"]
                log_dict["val/mae"] = val_metrics["mae"]
                log_dict["val/pearson_r"] = val_metrics["pearson_r"]
                log_dict["val/spearman_r"] = val_metrics["spearman_r"]
                if val_metrics["r2"] is not None:
                    log_dict["val/r2"] = val_metrics["r2"]
            wandb_run.log(log_dict)

        if (epoch + 1) % max(1, int(cfg.epochs) // 5) == 0 or epoch == 0:
            val_str = f", val_loss: {val_loss:.4f}" if val_loss is not None else ""
            logger.info(f"Epoch {epoch + 1}/{cfg.epochs} completed (train_loss: {avg_loss:.4f}{val_str})")

    # Load best model if available
    if best_model_state is not None and val_loader is not None and best_val_metrics is not None:
        score_str = f"R²={best_val_metrics['r2']:.4f}" if best_val_metrics.get("r2") is not None else f"RMSE={best_val_metrics['rmse']:.4f}"
        logger.info(f"Loading best model from epoch {best_epoch} (val_{score_str})...")
        model.load_state_dict(best_model_state)
        model.encoder.load_state_dict(best_encoder_state)
    else:
        logger.info("Using final epoch model (no validation set for checkpointing)")

    logger.info("Training complete. Evaluating...")
    # Evaluation + predictions
    model.eval()

    def _predict(loader):
        ys: list[float] = []
        preds: list[float] = []
        probs: list[float] = []
        mols: list[str] = []
        targs: list[str] = []
        splits: list[str] = []

        with torch.no_grad():
            for batch in loader:
                batch = batch.to(device)
                logits, _z = model(
                    batch.atom_type,
                    batch.atom_feat,
                    batch.edge_index,
                    batch.edge_attr,
                    batch.batch,
                )
                y_true = batch.y.view(-1).float().cpu().numpy()
                y_score = logits.view(-1).float().cpu().numpy()

                ys.extend(y_true.tolist())
                splits.extend([str(s) for s in batch.split])
                mols.extend([str(m) for m in batch.molecule_chembl_id])
                targs.extend([str(t) for t in batch.target_chembl_id])

                preds.extend(y_score.astype(float).tolist())
        return ys, preds, [], mols, targs, splits

    metrics: dict[str, Any] = {
        "task_type": task_type,
        "model": {
            "type": "gnn_2d_pyg",
            "architecture": str(cfg.architecture),
            "embedding_dim": int(cfg.embedding_dim),
            "hidden_dim": int(cfg.hidden_dim),
            "num_layers": int(cfg.num_layers),
            "dropout": float(cfg.dropout),
            "pooling": str(cfg.pooling),
            "residual": bool(cfg.residual),
            "head_mlp_layers": int(cfg.head_mlp_layers),
            "gin_conv_mlp_layers": int(cfg.gin_conv_mlp_layers),
            "gin_train_eps": bool(cfg.gin_train_eps),
            "gin_eps": float(cfg.gin_eps),
            "gat_heads": int(cfg.gat_heads),
            "sage_aggr": str(cfg.sage_aggr),
            "pna_deg_hist": pna_deg_list,
            "transformer_heads": int(cfg.transformer_heads),
            "transformer_edge_dim": cfg.transformer_edge_dim,
            "tag_k": int(cfg.tag_k),
            "arma_num_stacks": int(cfg.arma_num_stacks),
            "arma_num_layers": int(cfg.arma_num_layers),
            "cheb_k": int(cfg.cheb_k),
            "supergat_heads": int(cfg.supergat_heads),
            "supergat_attention_type": str(cfg.supergat_attention_type),
            "epochs": int(cfg.epochs),
            "batch_size": int(cfg.batch_size),
            "lr": float(cfg.lr),
            "weight_decay": float(cfg.weight_decay),
            "random_seed": int(cfg.random_seed),
            "best_epoch": best_epoch if best_epoch > 0 else None,  # Epoch number where best model was saved
        },
        "created_at": datetime.now(timezone.utc).isoformat(),
        "dta_gnn_version": __version__,
        "splits": {},
    }

    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
    )
    from scipy.stats import pearsonr, spearmanr
    def _split_metrics(
        y_true: list[float], y_pred: list[float], y_prob: list[float] | None
    ):
        if not y_true:
            return None
        yt = np.asarray(y_true, dtype=float)
        yp = np.asarray(y_pred, dtype=float)
        rmse = float(np.sqrt(mean_squared_error(yt, yp)))
        return {
            "n": int(yt.size),
            "rmse": rmse,
            "mae": float(mean_absolute_error(yt, yp)),
            "r2": float(r2_score(yt, yp)) if yt.size >= 2 else None,
            "pearson_r": float(pearsonr(yt, yp)[0]) if yt.size >= 2 else None,
            "spearman_r": float(spearmanr(yt, yp)[0]) if yt.size >= 2 else None,
        }

    # metrics for train/val/test
    for name, loader in [
        ("train", train_loader),
        ("val", val_loader),
        ("test", test_loader),
    ]:
        if loader is None:
            continue
        y_true, y_pred, y_prob, *_rest = _predict(loader)
        metrics["splits"][name] = _split_metrics(
            y_true, y_pred, None
        )

    # Predictions (val+test)
    pred_rows: list[dict[str, Any]] = []
    for name, loader in [("val", val_loader), ("test", test_loader)]:
        if loader is None:
            continue
        y_true, y_pred, y_prob, mols, targs, splits = _predict(loader)
        for i in range(len(y_true)):
            row: dict[str, Any] = {
                "molecule_chembl_id": mols[i],
                "target_chembl_id": targs[i],
                "label": float(y_true[i]),
                "split": splits[i],
                "y_pred": float(y_pred[i]),
            }
            pred_rows.append(row)

    df_preds = pd.DataFrame(pred_rows)

    # Write artifacts - use architecture name in filenames
    arch_name = str(cfg.architecture).strip().lower()
    model_path = run_path / f"model_gnn_{arch_name}.pt"
    encoder_path = run_path / f"encoder_{arch_name}.pt"
    encoder_config_path = run_path / f"encoder_{arch_name}_config.json"
    metrics_path = run_path / f"model_metrics_gnn_{arch_name}.json"
    predictions_path = run_path / f"model_predictions_gnn_{arch_name}.csv"

    torch.save(model.state_dict(), model_path)
    torch.save(model.encoder.state_dict(), encoder_path)

    encoder_cfg = {
        "created_at": metrics["created_at"],
        "dta_gnn_version": __version__,
        "featurizer": {
            "type": "rdkit_2d_graph",
            "atom_feat_dim": 6,
            "bond_feat_dim": 6,
            "atom_type_vocab": 101,
        },
        "encoder": {
            "type": "gnn_2d_pyg",
            "architecture": str(cfg.architecture),
            "embedding_dim": int(cfg.embedding_dim),
            "hidden_dim": int(cfg.hidden_dim),
            "num_layers": int(cfg.num_layers),
            "dropout": float(cfg.dropout),
            "pooling": str(cfg.pooling),
            "residual": bool(cfg.residual),
            "head_mlp_layers": int(cfg.head_mlp_layers),
            "gin_conv_mlp_layers": int(cfg.gin_conv_mlp_layers),
            "gin_train_eps": bool(cfg.gin_train_eps),
            "gin_eps": float(cfg.gin_eps),
            "gat_heads": int(cfg.gat_heads),
            "sage_aggr": str(cfg.sage_aggr),
            "pna_deg_hist": pna_deg_list,
            "transformer_heads": int(cfg.transformer_heads),
            "transformer_edge_dim": cfg.transformer_edge_dim,
            "tag_k": int(cfg.tag_k),
            "arma_num_stacks": int(cfg.arma_num_stacks),
            "arma_num_layers": int(cfg.arma_num_layers),
            "cheb_k": int(cfg.cheb_k),
            "supergat_heads": int(cfg.supergat_heads),
            "supergat_attention_type": str(cfg.supergat_attention_type),
        },
        "reproducibility": {"random_seed": int(cfg.random_seed)},
    }
    encoder_config_path.write_text(
        json.dumps(encoder_cfg, indent=2, sort_keys=True) + "\n"
    )

    metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True) + "\n")
    df_preds.to_csv(predictions_path, index=False)

    return GnnTrainResult(
        run_dir=run_path,
        task_type=task_type,
        model_path=model_path,
        encoder_path=encoder_path,
        encoder_config_path=encoder_config_path,
        metrics_path=metrics_path,
        predictions_path=predictions_path,
        metrics=metrics,
        best_epoch=best_epoch if best_epoch > 0 else None,  # Epoch number where best model was saved
    )

GnnEmbeddingExtractResult dataclass

GnnEmbeddingExtractResult(run_dir: Path, embeddings_path: Path, n_molecules: int, embedding_dim: int)

extract_gnn_embeddings_on_run

extract_gnn_embeddings_on_run(run_dir: str | Path, *, batch_size: int = 256, device: str | None = None) -> GnnEmbeddingExtractResult

Use a saved GNN encoder to generate embeddings for molecules in compounds.csv.

Requires
  • compounds.csv
  • encoder_<architecture>.pt
  • encoder_<architecture>_config.json
Writes
  • molecule_embeddings.npz (molecule_chembl_id, embeddings)
Source code in src/dta_gnn/models/gnn.py
def _find_encoder_artifacts(run_path: Path) -> tuple[str, Path, Path]:
    """Locate the saved encoder weights and config inside ``run_path``.

    Candidates are tried in order: the architecture recorded under
    ``model.architecture`` in ``metadata.json`` (when present), followed by
    every known architecture name. The first architecture for which both
    ``encoder_<arch>.pt`` and ``encoder_<arch>_config.json`` exist is used.

    Returns:
        ``(architecture_name, encoder_path, encoder_config_path)``.

    Raises:
        ValueError: If no complete weights/config pair exists in ``run_path``.
    """
    candidates: list[str] = []

    metadata_path = run_path / "metadata.json"
    if metadata_path.exists():
        # metadata.json is only a hint (the one written by the dataset build
        # has no model.architecture), so an unreadable or oddly shaped file
        # simply falls through to probing instead of aborting.
        try:
            meta = _load_json(metadata_path) or {}
            model_meta = meta.get("model")
            if isinstance(model_meta, dict) and "architecture" in model_meta:
                declared = str(model_meta.get("architecture", "")).strip().lower()
                if declared:
                    candidates.append(declared)
        except Exception:
            pass

    candidates.extend(
        [
            "gin",
            "gcn",
            "gat",
            "sage",
            "pna",
            "transformer",
            "tag",
            "arma",
            "cheb",
            "supergat",
        ]
    )

    for arch in candidates:
        weights = run_path / f"encoder_{arch}.pt"
        config = run_path / f"encoder_{arch}_config.json"
        if weights.exists() and config.exists():
            return arch, weights, config

    # Nothing usable was found; the message names the default architecture,
    # exactly as the previous gin fallback did.
    fallback = "gin"
    raise ValueError(
        f"Missing encoder artifacts. Train the GNN model first to create "
        f"encoder_{fallback}.pt and encoder_{fallback}_config.json."
    )


def extract_gnn_embeddings_on_run(
    run_dir: str | Path,
    *,
    batch_size: int = 256,
    device: str | None = None,
) -> GnnEmbeddingExtractResult:
    """Use a saved GNN encoder to generate embeddings for molecules in compounds.csv.

    Requires:
      - compounds.csv (columns ``molecule_chembl_id`` and ``smiles``)
      - encoder_<architecture>.pt
      - encoder_<architecture>_config.json

    Writes:
      - molecule_embeddings.npz (molecule_chembl_id, embeddings)

    Args:
        run_dir: Run folder, resolved through ``_resolve_run_dir``.
        batch_size: Number of molecule graphs per inference batch.
        device: Optional device name forwarded to ``_get_device``.

    Returns:
        GnnEmbeddingExtractResult with the run folder, the path of the written
        ``.npz`` file, the number of embedded molecules and the embedding width.

    Raises:
        ValueError: If compounds.csv or the encoder artifacts are missing,
            compounds.csv lacks the required columns or has no usable rows,
            no SMILES converts to a graph, or a PNA encoder config has no
            ``pna_deg_hist``.
    """

    _require_pyg()
    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    run_path = _resolve_run_dir(run_dir)
    compounds_path = run_path / "compounds.csv"

    if not compounds_path.exists():
        raise ValueError(f"Missing compounds.csv in run folder: {compounds_path}")

    _arch_name, encoder_path, encoder_config_path = _find_encoder_artifacts(run_path)

    cfg = _load_json(encoder_config_path) or {}
    enc_cfg = cfg.get("encoder") if isinstance(cfg.get("encoder"), dict) else {}
    feat_cfg = cfg.get("featurizer") if isinstance(cfg.get("featurizer"), dict) else {}

    # Configs without an explicit architecture are treated as GIN.
    architecture = str(enc_cfg.get("architecture") or "gin")

    # Every hyperparameter falls back to a default when the key is absent
    # (or falsy) in the saved config.
    pooling = str(enc_cfg.get("pooling") or "add")
    if pooling not in {"add", "mean", "max", "attention"}:
        pooling = "add"
    residual = bool(enc_cfg.get("residual") or False)
    head_mlp_layers = int(enc_cfg.get("head_mlp_layers") or 2)
    gin_conv_mlp_layers = int(enc_cfg.get("gin_conv_mlp_layers") or 2)
    gin_train_eps = bool(enc_cfg.get("gin_train_eps") or False)
    gin_eps = float(enc_cfg.get("gin_eps") or 0.0)
    gat_heads = int(enc_cfg.get("gat_heads") or 4)
    sage_aggr = str(enc_cfg.get("sage_aggr") or "mean")
    pna_deg_hist = enc_cfg.get("pna_deg_hist")
    transformer_heads = int(enc_cfg.get("transformer_heads") or 4)
    transformer_edge_dim = enc_cfg.get("transformer_edge_dim")
    tag_k = int(enc_cfg.get("tag_k") or 2)
    arma_num_stacks = int(enc_cfg.get("arma_num_stacks") or 1)
    arma_num_layers = int(enc_cfg.get("arma_num_layers") or 1)
    cheb_k = int(enc_cfg.get("cheb_k") or 2)
    supergat_heads = int(enc_cfg.get("supergat_heads") or 4)
    supergat_attention_type = str(enc_cfg.get("supergat_attention_type") or "MX")

    pna_deg = None
    if architecture.strip().lower() == "pna":
        # PNA needs the degree histogram that was saved alongside the encoder.
        if isinstance(pna_deg_hist, list) and pna_deg_hist:
            pna_deg = torch.tensor([int(x) for x in pna_deg_hist], dtype=torch.long)
        else:
            raise ValueError(
                "encoder_config.json missing pna_deg_hist needed for PNA embeddings."
            )

    embedding_dim = int(enc_cfg.get("embedding_dim") or 128)
    hidden_dim = int(enc_cfg.get("hidden_dim") or 128)
    num_layers = int(enc_cfg.get("num_layers") or 5)
    dropout = float(enc_cfg.get("dropout") or 0.0)

    num_atom_types = int(feat_cfg.get("atom_type_vocab") or 101)

    df_comp = pd.read_csv(compounds_path)
    if "molecule_chembl_id" not in df_comp.columns or "smiles" not in df_comp.columns:
        raise ValueError(
            "compounds.csv must contain 'molecule_chembl_id' and 'smiles'."
        )

    # One row per molecule id, and only rows that actually carry a SMILES.
    df_comp = (
        df_comp[["molecule_chembl_id", "smiles"]]
        .dropna()
        .drop_duplicates(subset=["molecule_chembl_id"], keep="first")
        .reset_index(drop=True)
    )
    if df_comp.empty:
        raise ValueError("compounds.csv has no molecules with SMILES.")

    graphs = build_graphs_2d(
        molecules=list(
            zip(
                df_comp["molecule_chembl_id"].astype(str), df_comp["smiles"].astype(str)
            )
        ),
        drop_failures=True,
    )
    if not graphs:
        raise ValueError("No valid SMILES could be converted to graphs.")

    # Take the bond-feature width from the first graph. When that graph has no
    # bonds, use the width recorded in the encoder config (6 if absent)
    # instead of a hard-coded constant.
    first_edge_attr = graphs[0].edge_attr
    if first_edge_attr.size:
        edge_dim = int(first_edge_attr.shape[1])
    else:
        edge_dim = int(feat_cfg.get("bond_feat_dim") or 6)

    GnnEncoder, _GnnPredictor = _make_encoder_and_model(
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        edge_dim=edge_dim,
        num_atom_types=num_atom_types,
        architecture=architecture,
        pooling=pooling,
        residual=residual,
        head_mlp_layers=head_mlp_layers,
        gin_conv_mlp_layers=gin_conv_mlp_layers,
        gin_train_eps=gin_train_eps,
        gin_eps=gin_eps,
        gat_heads=gat_heads,
        sage_aggr=sage_aggr,
        pna_deg=pna_deg,
        transformer_heads=transformer_heads,
        transformer_edge_dim=transformer_edge_dim,
        tag_k=tag_k,
        arma_num_stacks=arma_num_stacks,
        arma_num_layers=arma_num_layers,
        cheb_k=cheb_k,
        supergat_heads=supergat_heads,
        supergat_attention_type=supergat_attention_type,
    )
    encoder = GnnEncoder()
    device_obj = _get_device(device)

    # TransformerConv doesn't support MPS (scatter_reduce not implemented)
    # Fall back to CPU if transformer architecture is used on MPS
    if architecture.strip().lower() == "transformer" and str(device_obj) == "mps":
        import warnings
        import sys
        import os

        # Only show warning if not running in pytest (to reduce test noise)
        is_pytest = (
            "pytest" in sys.modules
            or os.environ.get("PYTEST_CURRENT_TEST") is not None
            or any("pytest" in arg for arg in sys.argv)
        )

        if not is_pytest:
            warnings.warn(
                "TransformerConv doesn't support MPS. Falling back to CPU. "
                "For better performance, use device='cpu' explicitly.",
                UserWarning,
            )
        device_obj = torch.device("cpu")

    encoder.load_state_dict(torch.load(encoder_path, map_location=device_obj))
    encoder.to(device_obj)
    encoder.eval()

    # ids and data_list are filled in lockstep so that row i of the embedding
    # matrix belongs to ids[i] (the loader below does not shuffle).
    data_list: list[Data] = []
    ids: list[str] = []
    for g in graphs:
        ids.append(str(g.molecule_chembl_id))
        data_list.append(
            Data(
                atom_type=torch.from_numpy(g.atom_type.astype(np.int64)),
                atom_feat=torch.from_numpy(g.atom_feat.astype(np.float32)),
                edge_index=torch.from_numpy(g.edge_index.astype(np.int64)),
                edge_attr=torch.from_numpy(g.edge_attr.astype(np.float32)),
                molecule_chembl_id=str(g.molecule_chembl_id),
                num_nodes=int(g.atom_feat.shape[0]),
            )
        )

    loader = DataLoader(data_list, batch_size=int(batch_size), shuffle=False)

    embs: list[np.ndarray] = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device_obj)
            z = encoder(
                batch.atom_type,
                batch.atom_feat,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
            )
            embs.append(z.detach().cpu().numpy().astype(np.float32))

    embeddings = (
        np.concatenate(embs, axis=0)
        if embs
        else np.zeros((0, embedding_dim), dtype=np.float32)
    )

    out_path = run_path / "molecule_embeddings.npz"
    np.savez_compressed(
        out_path,
        molecule_chembl_id=np.asarray(ids, dtype=object),
        embeddings=embeddings,
    )

    return GnnEmbeddingExtractResult(
        run_dir=run_path,
        embeddings_path=out_path,
        n_molecules=int(embeddings.shape[0]),
        embedding_dim=int(embeddings.shape[1])
        if embeddings.ndim == 2
        else int(embedding_dim),
    )

Prediction on new molecules

PredictionResult dataclass

PredictionResult(predictions: DataFrame, model_type: Literal['RandomForest', 'SVR', 'GNN'], model_path: str, run_dir: Path)

Result from model prediction.

predict_with_random_forest

predict_with_random_forest(run_dir: Path, smiles_list: list[str], molecule_ids: list[str] | None = None) -> PredictionResult

Predict using a trained RandomForest model.

Parameters:

Name Type Description Default
run_dir Path

Directory containing the trained model (model_rf.pkl)

required
smiles_list list[str]

List of SMILES strings to predict

required
molecule_ids list[str] | None

Optional list of molecule IDs (defaults to mol_0, mol_1, ...)

None

Returns:

Type Description
PredictionResult

PredictionResult with predictions DataFrame

Source code in src/dta_gnn/models/predict.py
def predict_with_random_forest(
    run_dir: Path,
    smiles_list: list[str],
    molecule_ids: list[str] | None = None,
) -> PredictionResult:
    """Predict using a trained RandomForest model.

    Each SMILES is featurized as a 2048-bit Morgan fingerprint (radius 2).
    SMILES that cannot be parsed are kept in the output with a ``None``
    prediction.

    Args:
        run_dir: Directory containing the trained model (model_rf.pkl)
        smiles_list: List of SMILES strings to predict
        molecule_ids: Optional list of molecule IDs (defaults to mol_0, mol_1, ...)

    Returns:
        PredictionResult with predictions DataFrame (columns ``molecule_id``,
        ``smiles``, ``prediction``). Successfully featurized molecules come
        first, in input order, followed by the failed ones.

    Raises:
        FileNotFoundError: If ``model_rf.pkl`` is missing from ``run_dir``.
        ValueError: If ``molecule_ids`` and ``smiles_list`` differ in length,
            or if no SMILES could be featurized.
    """
    import joblib
    from rdkit import Chem
    from rdkit.Chem import AllChem

    run_dir = Path(run_dir).resolve()
    model_path = run_dir / "model_rf.pkl"

    if not model_path.exists():
        raise FileNotFoundError(f"RandomForest model not found: {model_path}")

    # NOTE: joblib.load unpickles arbitrary objects; only point this at run
    # directories produced by a trusted training run.
    model = joblib.load(model_path)

    # Generate molecule IDs if not provided
    if molecule_ids is None:
        molecule_ids = [f"mol_{i}" for i in range(len(smiles_list))]

    if len(molecule_ids) != len(smiles_list):
        raise ValueError(
            f"Length mismatch: {len(smiles_list)} SMILES but {len(molecule_ids)} IDs"
        )

    def get_fp(smi: str) -> np.ndarray | None:
        """Return the Morgan bit vector for ``smi`` as an array, or None on failure."""
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                return None
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            return np.array(fp)
        except Exception:
            # Any featurization failure marks the molecule as invalid rather
            # than aborting the whole batch.
            return None

    fps = []
    valid_idx = []
    for idx, smi in enumerate(smiles_list):
        fp = get_fp(smi)
        if fp is not None:
            fps.append(fp)
            valid_idx.append(idx)

    if not fps:
        raise ValueError("No valid SMILES to generate predictions.")

    X = np.array(fps)

    # The reported value is always model.predict(X). The previous
    # predict_proba attempt produced a value that was never used and hid
    # unrelated errors behind a broad except.
    preds = model.predict(X)

    # Valid molecules first (aligned with the rows of X) ...
    results = [
        {
            "molecule_id": molecule_ids[orig_idx],
            "smiles": smiles_list[orig_idx],
            "prediction": float(preds[i]),
        }
        for i, orig_idx in enumerate(valid_idx)
    ]

    # ... then the failed ones. A set keeps the membership test O(1).
    valid_lookup = set(valid_idx)
    results.extend(
        {
            "molecule_id": molecule_ids[idx],
            "smiles": smiles_list[idx],
            "prediction": None,
        }
        for idx in range(len(smiles_list))
        if idx not in valid_lookup
    )

    df = pd.DataFrame(results)

    return PredictionResult(
        predictions=df,
        model_type="RandomForest",
        model_path=str(model_path),
        run_dir=run_dir,
    )

predict_with_svr

predict_with_svr(run_dir: Path, smiles_list: list[str], molecule_ids: list[str] | None = None) -> PredictionResult

Predict using a trained SVR model.

Expects model_svr.pkl in the run directory. Uses Morgan (ECFP4) fingerprints with radius=2, nBits=2048.

Source code in src/dta_gnn/models/predict.py
def predict_with_svr(
    run_dir: Path,
    smiles_list: list[str],
    molecule_ids: list[str] | None = None,
) -> PredictionResult:
    """Predict using a trained SVR model.

    Expects `model_svr.pkl` in the run directory.
    Uses Morgan (ECFP4) fingerprints with radius=2, nBits=2048.

    Args:
        run_dir: Directory containing the trained model (model_svr.pkl)
        smiles_list: List of SMILES strings to predict
        molecule_ids: Optional list of molecule IDs (defaults to mol_0, mol_1, ...)

    Returns:
        PredictionResult with predictions DataFrame (columns ``molecule_id``,
        ``smiles``, ``prediction``). Successfully featurized molecules come
        first, in input order; molecules whose SMILES could not be featurized
        follow with a ``None`` prediction.

    Raises:
        FileNotFoundError: If ``model_svr.pkl`` is missing from ``run_dir``.
        ValueError: If ``molecule_ids`` and ``smiles_list`` differ in length,
            or if no SMILES could be featurized.
    """

    import joblib
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem

    run_dir = Path(run_dir).resolve()
    model_path = run_dir / "model_svr.pkl"
    if not model_path.exists():
        raise FileNotFoundError(f"SVR model not found: {model_path}")

    # NOTE: joblib.load unpickles arbitrary objects; only point this at run
    # directories produced by a trusted training run.
    model = joblib.load(model_path)

    if molecule_ids is None:
        molecule_ids = [f"mol_{i}" for i in range(len(smiles_list))]
    if len(molecule_ids) != len(smiles_list):
        raise ValueError(
            f"Length mismatch: {len(smiles_list)} SMILES but {len(molecule_ids)} IDs"
        )

    def _fp(smi: str) -> np.ndarray | None:
        """Return the 2048-bit Morgan fingerprint of ``smi``, or None on failure."""
        try:
            mol = Chem.MolFromSmiles(str(smi))
            if mol is None:
                return None
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            arr = np.zeros((2048,), dtype=np.int8)
            DataStructs.ConvertToNumpyArray(fp, arr)
            return arr
        except Exception:
            # Any featurization failure marks the molecule as invalid rather
            # than aborting the whole batch.
            return None

    fps: list[np.ndarray] = []
    valid_idx: list[int] = []
    for idx, smi in enumerate(smiles_list):
        arr = _fp(smi)
        if arr is not None:
            fps.append(arr)
            valid_idx.append(idx)

    if not fps:
        raise ValueError("No valid SMILES to generate predictions.")

    X = np.asarray(fps, dtype=np.float32)
    y_pred = model.predict(X).astype(float)

    # Valid molecules first (aligned with the rows of X) ...
    results: list[dict[str, object]] = [
        {
            "molecule_id": molecule_ids[orig_idx],
            "smiles": smiles_list[orig_idx],
            "prediction": float(y_pred[i]),
        }
        for i, orig_idx in enumerate(valid_idx)
    ]

    # ... then the failed ones. A set keeps the membership test O(1) instead
    # of rescanning the list for every input molecule.
    valid_lookup = set(valid_idx)
    results.extend(
        {
            "molecule_id": molecule_ids[idx],
            "smiles": smiles_list[idx],
            "prediction": None,
        }
        for idx in range(len(smiles_list))
        if idx not in valid_lookup
    )

    df = pd.DataFrame(results)
    return PredictionResult(
        predictions=df,
        model_type="SVR",
        model_path=str(model_path),
        run_dir=run_dir,
    )

predict_with_gnn

predict_with_gnn(run_dir: Path, smiles_list: list[str], molecule_ids: list[str] | None = None, batch_size: int = 64, device: str | None = None, architecture: str | None = None) -> PredictionResult

Predict using a trained GNN model.

Parameters:

Name Type Description Default
run_dir Path

Directory containing the trained model (model_gnn_<architecture>.pt)

required
smiles_list list[str]

List of SMILES strings to predict

required
molecule_ids list[str] | None

Optional list of molecule IDs (defaults to mol_0, mol_1, ...)

None
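batch_size int

Batch size for inference

64
device str | None

Optional device name used for inference; when None, the library selects the device

None
architecture str | None

Optional architecture name (e.g. gin); when omitted, or when its model file is missing, the architecture is detected from the run directory

None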

Returns:

Type Description
PredictionResult

PredictionResult with predictions DataFrame

Source code in src/dta_gnn/models/predict.py
def predict_with_gnn(
    run_dir: Path,
    smiles_list: list[str],
    molecule_ids: list[str] | None = None,
    batch_size: int = 64,
    device: str | None = None,
    architecture: str | None = None,
) -> PredictionResult:
    """Predict using a trained GNN model.

    Args:
        run_dir: Directory containing the trained model
            (model_gnn_<architecture>.pt) and its encoder config
            (encoder_<architecture>_config.json)
        smiles_list: List of SMILES strings to predict
        molecule_ids: Optional list of molecule IDs (defaults to mol_0, mol_1, ...)
        batch_size: Batch size for inference
        device: Optional device name forwarded to ``_get_device``
        architecture: Optional architecture name. It is used when its model
            file exists; otherwise the architecture is detected from
            ``metadata.json`` and then by probing the known architecture names.

    Returns:
        PredictionResult with predictions DataFrame (columns ``molecule_id``,
        ``smiles``, ``prediction``). Molecules that could be converted to
        graphs come first, in input order; the rest follow with a ``None``
        prediction.

    Raises:
        ImportError: If PyTorch / PyTorch Geometric cannot be imported.
        FileNotFoundError: If no model file or no encoder config is found.
        ValueError: If a PNA config lacks ``pna_deg_hist``, if
            ``molecule_ids`` and ``smiles_list`` differ in length, or if no
            SMILES could be converted to a graph.
    """
    from dta_gnn.features.molecule_graphs import smiles_to_graph_2d
    from dta_gnn.models.gnn import _make_encoder_and_model, _get_device, _load_json

    try:
        import torch
        from torch_geometric.data import Data
        from torch_geometric.loader import DataLoader
    except ImportError as exc:
        raise ImportError(
            "PyTorch Geometric not installed. Install with: pip install torch torch-geometric"
        ) from exc

    run_dir = Path(run_dir).resolve()

    known_architectures = (
        "gin",
        "gcn",
        "gat",
        "sage",
        "pna",
        "transformer",
        "tag",
        "arma",
        "cheb",
        "supergat",
    )

    def _artifact_paths(arch: str) -> tuple[Path, Path]:
        """Return (model weights path, encoder config path) for ``arch``."""
        return (
            run_dir / f"model_gnn_{arch}.pt",
            run_dir / f"encoder_{arch}_config.json",
        )

    model_path: Path | None = None
    config_path: Path | None = None

    # 1) An explicitly requested architecture wins as soon as its model file
    #    exists (a missing config is then reported below, not silently
    #    replaced by another architecture).
    requested = str(architecture).strip().lower() if architecture else ""
    if requested:
        candidate_model, candidate_config = _artifact_paths(requested)
        if candidate_model.exists():
            model_path, config_path = candidate_model, candidate_config

    # 2) Otherwise detect: the architecture recorded in metadata.json first,
    #    then every known architecture. Previously a metadata.json without a
    #    usable architecture disabled probing and forced the gin fallback, so
    #    non-gin runs could not be auto-detected.
    if model_path is None:
        candidates: list[str] = []
        metadata_path = run_dir / "metadata.json"
        if metadata_path.exists():
            # Metadata is only a hint; an unreadable or oddly shaped file
            # must not prevent probing.
            try:
                meta = _load_json(metadata_path) or {}
                declared = (
                    str((meta.get("model") or {}).get("architecture") or "")
                    .strip()
                    .lower()
                )
            except Exception:
                declared = ""
            if declared:
                candidates.append(declared)
        candidates.extend(known_architectures)

        for arch in candidates:
            candidate_model, candidate_config = _artifact_paths(arch)
            if candidate_model.exists() and candidate_config.exists():
                model_path, config_path = candidate_model, candidate_config
                break

    # 3) Nothing found: fall back to gin so the error names a concrete file.
    if model_path is None or config_path is None:
        model_path, config_path = _artifact_paths("gin")

    if not model_path.exists():
        raise FileNotFoundError(f"GNN model not found: {model_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"GNN config not found: {config_path}")

    # Load model config
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    encoder_cfg = config.get("encoder") or {}
    feat_cfg = config.get("featurizer") or {}

    # The architecture used to rebuild the network comes from the saved
    # config; configs without one are treated as GIN.
    model_architecture = str(encoder_cfg.get("architecture") or "gin")

    pooling = str(encoder_cfg.get("pooling") or "add")
    if pooling not in {"add", "mean", "max", "attention"}:
        pooling = "add"

    pna_deg = None
    if model_architecture.strip().lower() == "pna":
        # PNA needs the degree histogram that was saved alongside the encoder.
        pna_deg_hist = encoder_cfg.get("pna_deg_hist")
        if isinstance(pna_deg_hist, list) and pna_deg_hist:
            pna_deg = torch.tensor([int(x) for x in pna_deg_hist], dtype=torch.long)
        else:
            raise ValueError(
                "encoder_config.json missing pna_deg_hist needed for PNA inference."
            )

    # Reconstruct the model with every hyperparameter stored in the encoder
    # config. The architecture-specific ones (transformer/tag/arma/cheb/
    # supergat) use the same keys and defaults as
    # extract_gnn_embeddings_on_run; omitting them rebuilds the network with
    # library defaults that may not match the saved weights.
    _, GnnPredictor = _make_encoder_and_model(
        embedding_dim=int(encoder_cfg.get("embedding_dim") or 128),
        hidden_dim=int(encoder_cfg.get("hidden_dim") or 128),
        num_layers=int(encoder_cfg.get("num_layers") or 5),
        dropout=float(encoder_cfg.get("dropout") or 0.0),
        edge_dim=int(feat_cfg.get("bond_feat_dim") or 6),
        num_atom_types=int(feat_cfg.get("atom_type_vocab") or 101),
        architecture=model_architecture,
        pooling=pooling,
        residual=bool(encoder_cfg.get("residual") or False),
        head_mlp_layers=int(encoder_cfg.get("head_mlp_layers") or 2),
        gin_conv_mlp_layers=int(encoder_cfg.get("gin_conv_mlp_layers") or 2),
        gin_train_eps=bool(encoder_cfg.get("gin_train_eps") or False),
        gin_eps=float(encoder_cfg.get("gin_eps") or 0.0),
        gat_heads=int(encoder_cfg.get("gat_heads") or 4),
        sage_aggr=str(encoder_cfg.get("sage_aggr") or "mean"),
        pna_deg=pna_deg,
        transformer_heads=int(encoder_cfg.get("transformer_heads") or 4),
        transformer_edge_dim=encoder_cfg.get("transformer_edge_dim"),
        tag_k=int(encoder_cfg.get("tag_k") or 2),
        arma_num_stacks=int(encoder_cfg.get("arma_num_stacks") or 1),
        arma_num_layers=int(encoder_cfg.get("arma_num_layers") or 1),
        cheb_k=int(encoder_cfg.get("cheb_k") or 2),
        supergat_heads=int(encoder_cfg.get("supergat_heads") or 4),
        supergat_attention_type=str(
            encoder_cfg.get("supergat_attention_type") or "MX"
        ),
    )

    model = GnnPredictor()
    device_obj = _get_device(device)

    # TransformerConv doesn't support MPS (scatter_reduce not implemented)
    # Fall back to CPU if transformer architecture is used on MPS
    if model_architecture.strip().lower() == "transformer" and str(device_obj) == "mps":
        import warnings
        import sys
        import os

        # Only show warning if not running in pytest (to reduce test noise)
        is_pytest = (
            "pytest" in sys.modules
            or os.environ.get("PYTEST_CURRENT_TEST") is not None
            or any("pytest" in arg for arg in sys.argv)
        )

        if not is_pytest:
            warnings.warn(
                "TransformerConv doesn't support MPS. Falling back to CPU. "
                "For better performance, use device='cpu' explicitly.",
                UserWarning,
            )
        device_obj = torch.device("cpu")

    state_dict = torch.load(model_path, map_location=device_obj)
    model.load_state_dict(state_dict)
    model.to(device_obj)
    model.eval()

    # Generate molecule IDs if not provided
    if molecule_ids is None:
        molecule_ids = [f"mol_{i}" for i in range(len(smiles_list))]

    if len(molecule_ids) != len(smiles_list):
        raise ValueError(
            f"Length mismatch: {len(smiles_list)} SMILES but {len(molecule_ids)} IDs"
        )

    # Build graphs; molecules that fail to convert are reported with a None
    # prediction instead of aborting the batch.
    graphs = []
    valid_idx = []
    for idx, smi in enumerate(smiles_list):
        try:
            g = smiles_to_graph_2d(molecule_chembl_id=molecule_ids[idx], smiles=smi)
            if g is not None:
                data = Data(
                    atom_type=torch.from_numpy(g.atom_type.astype(np.int64)),
                    atom_feat=torch.from_numpy(g.atom_feat.astype(np.float32)),
                    edge_index=torch.from_numpy(g.edge_index.astype(np.int64)),
                    edge_attr=torch.from_numpy(g.edge_attr.astype(np.float32)),
                )
                graphs.append(data)
                valid_idx.append(idx)
        except Exception:
            continue

    if not graphs:
        raise ValueError("No valid SMILES to generate predictions.")

    # The loader does not shuffle, so all_preds[i] belongs to valid_idx[i].
    loader = DataLoader(graphs, batch_size=batch_size, shuffle=False)
    all_preds: list[float] = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device_obj)
            logits, _z = model(
                batch.atom_type,
                batch.atom_feat,
                batch.edge_index,
                batch.edge_attr,
                batch.batch,
            )
            all_preds.extend(logits.view(-1).float().cpu().numpy().tolist())

    # Valid molecules first ...
    results = [
        {
            "molecule_id": molecule_ids[orig_idx],
            "smiles": smiles_list[orig_idx],
            "prediction": float(all_preds[i]),
        }
        for i, orig_idx in enumerate(valid_idx)
    ]

    # ... then the failed ones. A set keeps the membership test O(1).
    valid_lookup = set(valid_idx)
    results.extend(
        {
            "molecule_id": molecule_ids[idx],
            "smiles": smiles_list[idx],
            "prediction": None,
        }
        for idx in range(len(smiles_list))
        if idx not in valid_lookup
    )

    df = pd.DataFrame(results)

    return PredictionResult(
        predictions=df,
        model_type="GNN",
        model_path=str(model_path),
        run_dir=run_dir,
    )

Hyperparameter optimisation

HyperoptConfig dataclass

HyperoptConfig(model_type: Literal['RandomForest', 'SVR', 'GNN'], n_trials: int = 20, n_jobs: int = 1, sampler_seed: int = 42, rf_optimize_n_estimators: bool = False, rf_n_estimators_min: int = 50, rf_n_estimators_max: int = 500, rf_optimize_max_depth: bool = False, rf_max_depth_min: int = 5, rf_max_depth_max: int = 50, rf_optimize_min_samples_split: bool = False, rf_min_samples_split_min: int = 2, rf_min_samples_split_max: int = 20, svr_optimize_C: bool = False, svr_C_min: float = 0.1, svr_C_max: float = 100.0, svr_C_default: float = 10.0, svr_optimize_epsilon: bool = False, svr_epsilon_min: float = 0.01, svr_epsilon_max: float = 0.2, svr_epsilon_default: float = 0.1, svr_optimize_kernel: bool = False, svr_kernel_choices: list[str] = None, svr_kernel_default: str = 'rbf', architecture: Literal['gin', 'gcn', 'gat', 'sage', 'pna', 'transformer', 'tag', 'arma', 'cheb', 'supergat'] = 'gin', optimize_epochs: bool = False, epochs_min: int = 5, epochs_max: int = 50, epochs_default: int = 20, optimize_lr: bool = False, lr_min: float = 1e-05, lr_max: float = 0.01, optimize_batch_size: bool = False, batch_size_min: int = 16, batch_size_max: int = 256, batch_size_default: int = 64, optimize_embedding_dim: bool = False, embedding_dim_min: int = 32, embedding_dim_max: int = 512, embedding_dim_default: int = 128, optimize_hidden_dim: bool = False, hidden_dim_min: int = 32, hidden_dim_max: int = 512, hidden_dim_default: int = 128, optimize_num_layers: bool = False, num_layers_min: int = 1, num_layers_max: int = 5, num_layers_default: int = 3, optimize_dropout: bool = False, dropout_min: float = 0.0, dropout_max: float = 0.6, dropout_default: float = 0.1, optimize_pooling: bool = False, pooling_choices: list[str] = None, pooling_default: str = 'add', optimize_residual: bool = False, residual_default: bool = False, optimize_head_mlp_layers: bool = False, head_mlp_layers_min: int = 1, head_mlp_layers_max: int = 4, head_mlp_layers_default: int = 2, optimize_gin_conv_mlp_layers: bool 
= False, gin_conv_mlp_layers_min: int = 1, gin_conv_mlp_layers_max: int = 4, gin_conv_mlp_layers_default: int = 2, optimize_gin_train_eps: bool = False, gin_train_eps_default: bool = False, optimize_gin_eps: bool = False, gin_eps_min: float = 0.0, gin_eps_max: float = 1.0, gin_eps_default: float = 0.0, optimize_gat_heads: bool = False, gat_heads_min: int = 1, gat_heads_max: int = 8, gat_heads_default: int = 4, optimize_sage_aggr: bool = False, sage_aggr_choices: list[str] = None, sage_aggr_default: str = 'mean', optimize_transformer_heads: bool = False, transformer_heads_min: int = 1, transformer_heads_max: int = 8, transformer_heads_default: int = 4, optimize_tag_k: bool = False, tag_k_min: int = 1, tag_k_max: int = 5, tag_k_default: int = 2, optimize_arma_stacks: bool = False, arma_num_stacks_min: int = 1, arma_num_stacks_max: int = 3, arma_num_stacks_default: int = 1, optimize_arma_layers: bool = False, arma_num_layers_min: int = 1, arma_num_layers_max: int = 3, arma_num_layers_default: int = 1, optimize_cheb_k: bool = False, cheb_k_min: int = 1, cheb_k_max: int = 5, cheb_k_default: int = 2, optimize_supergat_heads: bool = False, supergat_heads_min: int = 1, supergat_heads_max: int = 8, supergat_heads_default: int = 4, optimize_supergat_attention_type: bool = False, supergat_attention_type_choices: list[str] = None, supergat_attention_type_default: str = 'MX', optimize_weight_decay: bool = False, weight_decay_min: float = 1e-06, weight_decay_max: float = 0.001, weight_decay_default: float = 0.0, device: str | None = None)

Configuration for hyperparameter optimization.

HyperoptResult dataclass

HyperoptResult(run_dir: Path, best_params: dict, best_value: float, best_trial_number: int, n_trials: int, study_path: str, best_params_path: str, strategy: Literal['holdout-val', 'cv'], cv_folds_used: Optional[int])

Result from hyperparameter optimization.

optimize_random_forest_wandb

optimize_random_forest_wandb(run_dir: Path, *, config: HyperoptConfig, project: str, entity: str | None = None, api_key: str | None = None, sweep_name: str | None = None, radius: int = 2, n_bits: int = 2048) -> HyperoptResult

Optimize RandomForest hyperparameters using a W&B Bayes sweep.

Uses holdout validation if a val split exists; otherwise cross-validation (KFold for regression).

Source code in src/dta_gnn/models/hyperopt.py
def optimize_random_forest_wandb(
    run_dir: Path,
    *,
    config: HyperoptConfig,
    project: str,
    entity: str | None = None,
    api_key: str | None = None,
    sweep_name: str | None = None,
    radius: int = 2,
    n_bits: int = 2048,
) -> HyperoptResult:
    """Optimize RandomForest hyperparameters using a W&B Bayes sweep.

    Every trial fits a ``RandomForestRegressor`` on Morgan fingerprints and
    reports a single scalar ``val_score`` that the sweep maximizes. Rows whose
    ``split`` value is ``test`` are never used for tuning.

    Uses:
    - Holdout validation if a val split exists: the score is R^2 on the val
      rows (negative RMSE when fewer than two val rows are available).
    - Otherwise CV (shuffled KFold, at most 5 folds): the score is the mean
      R^2 across folds.

    Args:
        run_dir: Directory holding ``dataset.csv`` and ``compounds.csv`` (and
            optionally ``metadata.json``). The best parameters are written
            back into this directory as JSON.
        config: Supplies the ``rf_optimize_*`` switches, their min/max ranges
            and ``n_trials``.
        project: W&B project name; a blank value falls back to ``"dta_gnn"``.
        entity: W&B entity; a blank value is treated as ``None``.
        api_key: When non-blank, used for ``wandb.login(..., relogin=True)``.
        sweep_name: Sweep name; defaults to ``dta_gnn_rf_<task_type>``.
        radius: Morgan fingerprint radius.
        n_bits: Morgan fingerprint length in bits.

    Returns:
        A ``HyperoptResult``. ``study_path`` carries the W&B sweep id,
        ``best_value`` is ``0.0`` when no trial produced a score, and
        ``cv_folds_used`` is the number of KFold splits when the ``"cv"``
        strategy was used (``None`` for ``"holdout-val"``).

    Raises:
        FileNotFoundError: If ``dataset.csv`` or ``compounds.csv`` is missing.
        ValueError: If required columns are missing, no dataset row has a
            fingerprint, no parameter is enabled for optimization, or a val
            split exists but the train/val partitions are too small.
    """

    wandb = _require_wandb()

    run_dir = Path(run_dir).resolve()
    dataset_path = run_dir / "dataset.csv"
    compounds_path = run_dir / "compounds.csv"
    metadata_path = run_dir / "metadata.json"

    if not dataset_path.exists() or not compounds_path.exists():
        raise FileNotFoundError(f"Expected {dataset_path} and {compounds_path}")

    df = pd.read_csv(dataset_path)
    compounds = pd.read_csv(compounds_path)

    if "molecule_chembl_id" not in df.columns or "label" not in df.columns:
        raise ValueError("dataset.csv must contain 'molecule_chembl_id' and 'label'.")
    if (
        "molecule_chembl_id" not in compounds.columns
        or "smiles" not in compounds.columns
    ):
        raise ValueError(
            "compounds.csv must contain 'molecule_chembl_id' and 'smiles'."
        )

    # Metadata is optional and best-effort: an unreadable or malformed file is
    # treated the same as a missing one.
    meta = None
    try:
        if metadata_path.exists():
            meta = json.loads(metadata_path.read_text())
    except Exception:
        meta = None

    task_type = _infer_task_type_from_metadata_or_labels(df, meta)

    # Compute fingerprints once; they do not depend on sampled hyperparameters.
    df_comp = (
        compounds[["molecule_chembl_id", "smiles"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    df_feat = calculate_morgan_fingerprints(
        df_comp,
        smiles_col="smiles",
        radius=int(radius),
        n_bits=int(n_bits),
        out_col="morgan_fingerprint",
        drop_failures=True,
    )
    df_feat = (
        df_feat[["molecule_chembl_id", "morgan_fingerprint"]].dropna().drop_duplicates()
    )
    fp_map = dict(
        zip(
            df_feat["molecule_chembl_id"].astype(str),
            df_feat["morgan_fingerprint"].astype(str),
        )
    )

    # Keep only the dataset rows whose molecule has a fingerprint.
    df2 = df.copy()
    df2["_fp"] = df2["molecule_chembl_id"].astype(str).map(fp_map)
    df2 = df2.dropna(subset=["_fp"]).reset_index(drop=True)
    if df2.empty:
        raise ValueError("No rows left after joining fingerprints to dataset.")

    X = _bitstrings_to_numpy(df2["_fp"].astype(str).tolist(), n_bits=int(n_bits))
    y = df2["label"].to_numpy()

    # Exclude test split for tuning when present.
    if "split" in df2.columns:
        base_mask = df2["split"].astype(str).ne("test").to_numpy()
        has_val = bool((df2.loc[base_mask, "split"].astype(str) == "val").any())
    else:
        base_mask = np.ones(len(df2), dtype=bool)
        has_val = False

    project = (project or "").strip() or "dta_gnn"
    entity = (entity or "").strip() or None
    if api_key and str(api_key).strip():
        wandb.login(key=str(api_key).strip(), relogin=True)

    # Build sweep parameter space only from enabled knobs.
    parameters: dict[str, dict] = {}
    if config.rf_optimize_n_estimators:
        parameters["n_estimators"] = {
            "distribution": "int_uniform",
            "min": int(config.rf_n_estimators_min),
            "max": int(config.rf_n_estimators_max),
        }
    if config.rf_optimize_max_depth:
        parameters["max_depth"] = {
            "distribution": "int_uniform",
            "min": int(config.rf_max_depth_min),
            "max": int(config.rf_max_depth_max),
        }
    if config.rf_optimize_min_samples_split:
        parameters["min_samples_split"] = {
            "distribution": "int_uniform",
            "min": int(config.rf_min_samples_split_min),
            "max": int(config.rf_min_samples_split_max),
        }

    if not parameters:
        raise ValueError(
            "No parameters selected for optimization. "
            "Enable at least one 'Optimize ...' checkbox before running a sweep."
        )

    # The evaluation protocol is identical for every trial, so resolve it once
    # and fail fast *before* a sweep or run is created on W&B. Raising from
    # inside a trial would leave that run unfinished.
    cv_folds: int | None = None
    if has_val:
        split_labels = df2["split"].astype(str)
        train_mask = base_mask & split_labels.eq("train").to_numpy()
        val_mask = base_mask & split_labels.eq("val").to_numpy()
        if int(train_mask.sum()) < 2 or int(val_mask.sum()) < 1:
            raise ValueError(
                "Validation split exists but is too small for RF sweep."
            )
        X_train, y_train = X[train_mask], y[train_mask].astype(float)
        X_val, y_val = X[val_mask], y[val_mask].astype(float)
    else:
        X_base = X[base_mask]
        y_base = y[base_mask].astype(float)
        # Use simple KFold CV (non-stratified): at most 5 folds, at least 2.
        cv_folds = max(min(5, int(len(y_base))), 2)

    sweep_config: dict[str, object] = {
        "name": sweep_name or f"dta_gnn_rf_{task_type}",
        "method": "bayes",
        "metric": {"name": "val_score", "goal": "maximize"},
        "parameters": parameters,
    }

    sweep_id = wandb.sweep(sweep=sweep_config, project=project, entity=entity)

    best_score = -math.inf
    best_params: dict[str, object] = {}
    best_trial_number = -1
    # Mutable cell so the trial closure can number trials in call order.
    trial_counter = {"i": 0}

    def _trial_fn():
        nonlocal best_score, best_params, best_trial_number

        run = wandb.init(project=project, entity=entity, config={})
        trial_idx = int(trial_counter["i"])
        trial_counter["i"] = trial_idx + 1

        # Values chosen by the sweep; knobs that are not being optimized are
        # absent and fall back to the fixed defaults below.
        sampled = dict(getattr(wandb, "config", {}) or {})

        n_estimators = int(sampled.get("n_estimators", 500))
        max_depth = sampled.get("max_depth", None)
        max_depth = int(max_depth) if max_depth is not None else None
        min_samples_split = int(sampled.get("min_samples_split", 2))

        score: float
        extra_logs: dict[str, object] = {}

        model = RandomForestRegressor(
            n_estimators=n_estimators,
            random_state=42,
            n_jobs=-1,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
        )

        if has_val:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_val).astype(float)
            rmse = float(np.sqrt(mean_squared_error(y_val, y_pred)))
            extra_logs["val/rmse"] = rmse
            if y_val.size >= 2:
                r2 = float(r2_score(y_val, y_pred))
                extra_logs["val/r2"] = r2
                score = r2
            else:
                # R^2 needs at least two samples; use negated RMSE so that
                # "higher is better" still holds for the sweep metric.
                score = -rmse
        else:
            cv = KFold(n_splits=cv_folds, shuffle=True, random_state=42)
            scores = cross_val_score(model, X_base, y_base, cv=cv, scoring="r2")
            score = float(np.mean(scores)) if len(scores) else 0.0
            extra_logs["cv/r2"] = score
            extra_logs["cv/folds"] = int(cv_folds)

        wandb.log({"val_score": float(score), **extra_logs})
        run.summary["val_score"] = float(score)
        run.summary["task_type"] = task_type
        run.finish()

        if float(score) > float(best_score):
            best_score = float(score)
            best_trial_number = int(trial_idx)
            best_params = {
                "n_estimators": n_estimators,
                "max_depth": max_depth,
                "min_samples_split": min_samples_split,
                "radius": int(radius),
                "n_bits": int(n_bits),
                "task_type": task_type,
            }

    wandb.agent(sweep_id, function=_trial_fn, count=int(config.n_trials))

    best_params_path = str(run_dir / f"hyperopt_best_params_wandb_rf_{task_type}.json")
    with open(best_params_path, "w") as f:
        json.dump(best_params, f, indent=2)

    return HyperoptResult(
        run_dir=run_dir,
        best_params=best_params,
        best_value=float(best_score) if best_score != -math.inf else 0.0,
        best_trial_number=int(best_trial_number),
        n_trials=int(config.n_trials),
        study_path=str(sweep_id),
        best_params_path=str(best_params_path),
        strategy="holdout-val" if has_val else "cv",
        cv_folds_used=cv_folds,
    )

optimize_svr_wandb

optimize_svr_wandb(run_dir: Path, *, config: HyperoptConfig, project: str, entity: str | None = None, api_key: str | None = None, sweep_name: str | None = None, radius: int = 2, n_bits: int = 2048) -> HyperoptResult

Optimize SVR hyperparameters using a W&B Bayes sweep.

Notes: The SVR sweep is intended for regression runs (DTA). It uses holdout validation if a val split exists and KFold CV otherwise, and it logs a single scalar metric, val_score, which the sweep maximizes.

Source code in src/dta_gnn/models/hyperopt.py
def optimize_svr_wandb(
    run_dir: Path,
    *,
    config: HyperoptConfig,
    project: str,
    entity: str | None = None,
    api_key: str | None = None,
    sweep_name: str | None = None,
    radius: int = 2,
    n_bits: int = 2048,
) -> HyperoptResult:
    """Optimize SVR hyperparameters using a W&B Bayes sweep.

    Every trial fits an ``SVR`` (``rbf`` or ``linear`` kernel) on Morgan
    fingerprints. Rows whose ``split`` value is ``test`` are never used for
    tuning.

    Notes:
    - SVR sweep is intended for regression runs (DTA).
    - Uses holdout validation if a val split exists (score is R^2 on the val
      rows, or negative RMSE with fewer than two val rows); otherwise uses
      shuffled KFold CV with at most 5 folds (score is the mean R^2).
    - Logs a single scalar metric `val_score` (maximize).

    Args:
        run_dir: Directory holding ``dataset.csv`` and ``compounds.csv`` (and
            optionally ``metadata.json``). The best parameters are written
            back into this directory as JSON.
        config: Supplies the ``svr_optimize_*`` switches, their ranges,
            kernel choices, per-knob defaults and ``n_trials``.
        project: W&B project name; a blank value falls back to ``"dta_gnn"``.
        entity: W&B entity; a blank value is treated as ``None``.
        api_key: When non-blank, used for ``wandb.login(..., relogin=True)``.
        sweep_name: Sweep name; defaults to ``"dta_gnn_svr"``.
        radius: Morgan fingerprint radius.
        n_bits: Morgan fingerprint length in bits.

    Returns:
        A ``HyperoptResult``. ``study_path`` carries the W&B sweep id,
        ``best_value`` is ``0.0`` when no trial produced a score, and
        ``cv_folds_used`` is the number of KFold splits when the ``"cv"``
        strategy was used (``None`` for ``"holdout-val"``).

    Raises:
        FileNotFoundError: If ``dataset.csv`` or ``compounds.csv`` is missing.
        ValueError: If required columns are missing, the task is not
            regression, no dataset row has a fingerprint, no parameter is
            enabled for optimization, or a val split exists but the train/val
            partitions are too small.
    """

    wandb = _require_wandb()

    run_dir = Path(run_dir).resolve()
    dataset_path = run_dir / "dataset.csv"
    compounds_path = run_dir / "compounds.csv"
    metadata_path = run_dir / "metadata.json"

    if not dataset_path.exists() or not compounds_path.exists():
        raise FileNotFoundError(f"Expected {dataset_path} and {compounds_path}")

    df = pd.read_csv(dataset_path)
    compounds = pd.read_csv(compounds_path)

    if "molecule_chembl_id" not in df.columns or "label" not in df.columns:
        raise ValueError("dataset.csv must contain 'molecule_chembl_id' and 'label'.")
    if (
        "molecule_chembl_id" not in compounds.columns
        or "smiles" not in compounds.columns
    ):
        raise ValueError(
            "compounds.csv must contain 'molecule_chembl_id' and 'smiles'."
        )

    # Metadata is optional and best-effort: an unreadable or malformed file is
    # treated the same as a missing one.
    meta = None
    try:
        if metadata_path.exists():
            meta = json.loads(metadata_path.read_text())
    except Exception:
        meta = None

    task_type = _infer_task_type_from_metadata_or_labels(df, meta)
    if task_type != "regression":
        raise ValueError(
            "SVR sweeps are supported for regression runs only. "
            "Build a regression (DTA) dataset or choose a different model for HPO."
        )

    # Compute fingerprints once; they do not depend on sampled hyperparameters.
    df_comp = (
        compounds[["molecule_chembl_id", "smiles"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    df_feat = calculate_morgan_fingerprints(
        df_comp,
        smiles_col="smiles",
        radius=int(radius),
        n_bits=int(n_bits),
        out_col="morgan_fingerprint",
        drop_failures=True,
    )
    df_feat = (
        df_feat[["molecule_chembl_id", "morgan_fingerprint"]].dropna().drop_duplicates()
    )
    fp_map = dict(
        zip(
            df_feat["molecule_chembl_id"].astype(str),
            df_feat["morgan_fingerprint"].astype(str),
        )
    )

    # Keep only the dataset rows whose molecule has a fingerprint.
    df2 = df.copy()
    df2["_fp"] = df2["molecule_chembl_id"].astype(str).map(fp_map)
    df2 = df2.dropna(subset=["_fp"]).reset_index(drop=True)
    if df2.empty:
        raise ValueError("No rows left after joining fingerprints to dataset.")

    X = _bitstrings_to_numpy(df2["_fp"].astype(str).tolist(), n_bits=int(n_bits))
    y = df2["label"].astype(float).to_numpy()

    # Exclude test split for tuning when present.
    if "split" in df2.columns:
        base_mask = df2["split"].astype(str).ne("test").to_numpy()
        has_val = bool((df2.loc[base_mask, "split"].astype(str) == "val").any())
    else:
        base_mask = np.ones(len(df2), dtype=bool)
        has_val = False

    project = (project or "").strip() or "dta_gnn"
    entity = (entity or "").strip() or None
    if api_key and str(api_key).strip():
        wandb.login(key=str(api_key).strip(), relogin=True)

    # Build sweep parameter space only from enabled knobs.
    parameters: dict[str, dict] = {}

    if config.svr_optimize_C:
        parameters["C"] = {
            "distribution": "log_uniform_values",
            "min": float(config.svr_C_min),
            "max": float(config.svr_C_max),
        }
    if config.svr_optimize_epsilon:
        parameters["epsilon"] = {
            "distribution": "log_uniform_values",
            "min": float(config.svr_epsilon_min),
            "max": float(config.svr_epsilon_max),
        }

    # Only "rbf" and "linear" are accepted; anything else is dropped, and an
    # empty result falls back to both kernels.
    kernel_choices = getattr(config, "svr_kernel_choices", None) or ["rbf", "linear"]
    kernel_choices = [
        str(k).strip().lower()
        for k in kernel_choices
        if str(k).strip().lower() in {"rbf", "linear"}
    ]
    if not kernel_choices:
        kernel_choices = ["rbf", "linear"]
    if config.svr_optimize_kernel:
        parameters["kernel"] = {"values": kernel_choices}

    if not parameters:
        raise ValueError(
            "No parameters selected for optimization. "
            "Enable at least one 'Optimize ...' checkbox before running a sweep."
        )

    # The evaluation protocol is identical for every trial, so resolve it once
    # and fail fast *before* a sweep or run is created on W&B. Raising from
    # inside a trial would leave that run unfinished.
    cv_folds: int | None = None
    if has_val:
        split_labels = df2["split"].astype(str)
        train_mask = base_mask & split_labels.eq("train").to_numpy()
        val_mask = base_mask & split_labels.eq("val").to_numpy()
        if int(train_mask.sum()) < 2 or int(val_mask.sum()) < 1:
            raise ValueError(
                "Validation split exists but is too small for SVR sweep."
            )
        X_train, y_train = X[train_mask], y[train_mask]
        X_val, y_val = X[val_mask], y[val_mask]
    else:
        X_base = X[base_mask]
        y_base = y[base_mask]
        # At most 5 folds, at least 2.
        cv_folds = max(min(5, int(len(y_base))), 2)

    sweep_config: dict[str, object] = {
        "name": sweep_name or "dta_gnn_svr",
        "method": "bayes",
        "metric": {"name": "val_score", "goal": "maximize"},
        "parameters": parameters,
    }

    sweep_id = wandb.sweep(sweep=sweep_config, project=project, entity=entity)

    best_score = -math.inf
    best_params: dict[str, object] = {}
    best_trial_number = -1
    # Mutable cell so the trial closure can number trials in call order.
    trial_counter = {"i": 0}

    def _trial_fn():
        nonlocal best_score, best_params, best_trial_number

        run = wandb.init(project=project, entity=entity, config={})
        trial_idx = int(trial_counter["i"])
        trial_counter["i"] = trial_idx + 1

        # Values chosen by the sweep; knobs that are not being optimized are
        # absent and fall back to the config defaults below.
        sampled = dict(getattr(wandb, "config", {}) or {})

        C = float(sampled.get("C", float(getattr(config, "svr_C_default", 10.0))))
        epsilon = float(
            sampled.get("epsilon", float(getattr(config, "svr_epsilon_default", 0.1)))
        )

        k = (
            str(
                sampled.get("kernel", str(getattr(config, "svr_kernel_default", "rbf")))
            )
            .strip()
            .lower()
        )
        if k not in {"rbf", "linear"}:
            k = "rbf"

        model = SVR(kernel=k, C=C, epsilon=epsilon)

        score: float
        extra_logs: dict[str, object] = {}

        if has_val:
            model.fit(X_train, y_train)
            y_pred = model.predict(X_val).astype(float)
            rmse = float(np.sqrt(mean_squared_error(y_val, y_pred)))
            extra_logs["val/rmse"] = rmse
            if y_val.size >= 2:
                r2 = float(r2_score(y_val, y_pred))
                extra_logs["val/r2"] = r2
                score = r2
            else:
                # R^2 needs at least two samples; use negated RMSE so that
                # "higher is better" still holds for the sweep metric.
                score = -rmse
        else:
            cv = KFold(n_splits=cv_folds, shuffle=True, random_state=42)
            scores = cross_val_score(model, X_base, y_base, cv=cv, scoring="r2")
            score = float(np.mean(scores)) if len(scores) else 0.0
            extra_logs["cv/r2"] = score
            extra_logs["cv/folds"] = int(cv_folds)

        wandb.log({"val_score": float(score), **extra_logs})
        run.summary["val_score"] = float(score)
        run.summary["task_type"] = task_type
        run.finish()

        if float(score) > float(best_score):
            best_score = float(score)
            best_trial_number = int(trial_idx)
            best_params = {
                "C": float(C),
                "epsilon": float(epsilon),
                "kernel": k,
                "radius": int(radius),
                "n_bits": int(n_bits),
                "task_type": task_type,
            }

    wandb.agent(sweep_id, function=_trial_fn, count=int(config.n_trials))

    best_params_path = str(run_dir / "hyperopt_best_params_wandb_svr.json")
    with open(best_params_path, "w") as f:
        json.dump(best_params, f, indent=2)

    return HyperoptResult(
        run_dir=run_dir,
        best_params=best_params,
        best_value=float(best_score) if best_score != -math.inf else 0.0,
        best_trial_number=int(best_trial_number),
        n_trials=int(config.n_trials),
        study_path=str(sweep_id),
        best_params_path=str(best_params_path),
        strategy="holdout-val" if has_val else "cv",
        cv_folds_used=cv_folds,
    )

optimize_gnn_wandb

optimize_gnn_wandb(run_dir: Path, *, config: HyperoptConfig, project: str, entity: str | None = None, api_key: str | None = None, sweep_name: str | None = None) -> HyperoptResult

Optimize GNN hyperparameters using a W&B Bayes sweep.

Notes: The sweep uses the existing train/val split and therefore requires a non-empty val split. Each trial runs in an isolated subdirectory so that artifacts do not overwrite one another, and a single scalar metric, val_score, is logged and maximized.

Source code in src/dta_gnn/models/hyperopt.py
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
def optimize_gnn_wandb(
    run_dir: Path,
    *,
    config: HyperoptConfig,
    project: str,
    entity: str | None = None,
    api_key: str | None = None,
    sweep_name: str | None = None,
) -> HyperoptResult:
    """Optimize GNN hyperparameters using a W&B Bayes sweep.

    Notes:
    - Uses the existing train/val split (requires a non-empty val split).
    - Runs each trial in an isolated subdirectory so artifacts don't overwrite.
    - Logs a single scalar metric `val_score` (maximize).
    """

    wandb = _require_wandb()

    run_dir = Path(run_dir).resolve()
    dataset_path = run_dir / "dataset.csv"
    compounds_path = run_dir / "compounds.csv"
    metadata_path = run_dir / "metadata.json"

    if not dataset_path.exists() or not compounds_path.exists():
        raise FileNotFoundError(f"Expected {dataset_path} and {compounds_path}")

    df = pd.read_csv(dataset_path)
    if "split" not in df.columns or not bool((df["split"] == "val").any()):
        raise ValueError(
            "W&B sweeps currently require an explicit validation split. "
            "Rebuild the dataset with val_size > 0.0."
        )

    arch = str(getattr(config, "architecture", "gin") or "gin").strip().lower()
    _ALL_ARCHS = {"gin", "gcn", "gat", "sage", "pna", "transformer", "tag", "arma", "cheb", "supergat"}
    if arch not in _ALL_ARCHS:
        arch = "gin"

    project = (project or "").strip() or "dta_gnn"
    entity = (entity or "").strip() or None
    if api_key and str(api_key).strip():
        wandb.login(key=str(api_key).strip(), relogin=True)

    # Build sweep parameter space only from enabled knobs.
    parameters: dict[str, dict] = {}

    def _int_values(min_v: int, max_v: int, step: int = 1) -> list[int]:
        if step <= 1:
            return list(range(int(min_v), int(max_v) + 1))
        return list(range(int(min_v), int(max_v) + 1, int(step)))

    if config.optimize_epochs:
        parameters["epochs"] = {
            "distribution": "int_uniform",
            "min": int(config.epochs_min),
            "max": int(config.epochs_max),
        }
    if config.optimize_batch_size:
        parameters["batch_size"] = {
            "distribution": "int_uniform",
            "min": int(config.batch_size_min),
            "max": int(config.batch_size_max),
        }
    if config.optimize_lr:
        parameters["lr"] = {
            "distribution": "log_uniform_values",
            "min": float(config.lr_min),
            "max": float(config.lr_max),
        }
    if config.optimize_embedding_dim:
        parameters["embedding_dim"] = {
            "values": _int_values(
                int(config.embedding_dim_min),
                int(config.embedding_dim_max),
                step=16,
            )
        }
    if config.optimize_hidden_dim:
        parameters["hidden_dim"] = {
            "values": _int_values(
                int(config.hidden_dim_min), int(config.hidden_dim_max), step=16
            )
        }

    if config.optimize_num_layers:
        parameters["num_layers"] = {
            "distribution": "int_uniform",
            "min": int(config.num_layers_min),
            "max": int(config.num_layers_max),
        }

    if config.optimize_dropout:
        parameters["dropout"] = {
            "distribution": "uniform",
            "min": float(config.dropout_min),
            "max": float(config.dropout_max),
        }

    pooling_choices = getattr(config, "pooling_choices", None) or [
        "add",
        "mean",
        "max",
        "attention",
    ]
    pooling_choices = [
        str(x) for x in pooling_choices if str(x) in {"add", "mean", "max", "attention"}
    ]
    if not pooling_choices:
        pooling_choices = ["add", "mean", "max", "attention"]
    if config.optimize_pooling:
        parameters["pooling"] = {"values": pooling_choices}

    if config.optimize_residual:
        parameters["residual"] = {"values": [False, True]}

    if config.optimize_head_mlp_layers:
        parameters["head_mlp_layers"] = {
            "distribution": "int_uniform",
            "min": int(getattr(config, "head_mlp_layers_min", 1)),
            "max": int(getattr(config, "head_mlp_layers_max", 4)),
        }

    if arch == "gin":
        if config.optimize_gin_conv_mlp_layers:
            parameters["gin_conv_mlp_layers"] = {
                "distribution": "int_uniform",
                "min": int(getattr(config, "gin_conv_mlp_layers_min", 1)),
                "max": int(getattr(config, "gin_conv_mlp_layers_max", 4)),
            }
        if config.optimize_gin_train_eps:
            parameters["gin_train_eps"] = {"values": [False, True]}
        if config.optimize_gin_eps:
            parameters["gin_eps"] = {
                "distribution": "uniform",
                "min": float(getattr(config, "gin_eps_min", 0.0)),
                "max": float(getattr(config, "gin_eps_max", 1.0)),
            }

    if arch == "gat" and config.optimize_gat_heads:
        parameters["gat_heads"] = {
            "distribution": "int_uniform",
            "min": int(getattr(config, "gat_heads_min", 1)),
            "max": int(getattr(config, "gat_heads_max", 8)),
        }

    if arch == "sage":
        sage_aggr_choices = getattr(config, "sage_aggr_choices", None) or [
            "mean",
            "max",
            "add",
        ]
        sage_aggr_choices = [
            str(x) for x in sage_aggr_choices if str(x) in {"mean", "max", "add"}
        ]
        if not sage_aggr_choices:
            sage_aggr_choices = ["mean", "max", "add"]
        if config.optimize_sage_aggr:
            parameters["sage_aggr"] = {"values": sage_aggr_choices}

    if arch == "transformer" and config.optimize_transformer_heads:
        parameters["transformer_heads"] = {
            "distribution": "int_uniform",
            "min": int(getattr(config, "transformer_heads_min", 1)),
            "max": int(getattr(config, "transformer_heads_max", 8)),
        }

    if arch == "tag" and config.optimize_tag_k:
        parameters["tag_k"] = {
            "distribution": "int_uniform",
            "min": int(getattr(config, "tag_k_min", 1)),
            "max": int(getattr(config, "tag_k_max", 5)),
        }

    if arch == "arma":
        if config.optimize_arma_stacks:
            parameters["arma_num_stacks"] = {
                "distribution": "int_uniform",
                "min": int(getattr(config, "arma_num_stacks_min", 1)),
                "max": int(getattr(config, "arma_num_stacks_max", 3)),
            }
        if config.optimize_arma_layers:
            parameters["arma_num_layers"] = {
                "distribution": "int_uniform",
                "min": int(getattr(config, "arma_num_layers_min", 1)),
                "max": int(getattr(config, "arma_num_layers_max", 3)),
            }

    if arch == "cheb" and config.optimize_cheb_k:
        parameters["cheb_k"] = {
            "distribution": "int_uniform",
            "min": int(getattr(config, "cheb_k_min", 1)),
            "max": int(getattr(config, "cheb_k_max", 5)),
        }

    if arch == "supergat":
        if config.optimize_supergat_heads:
            parameters["supergat_heads"] = {
                "distribution": "int_uniform",
                "min": int(getattr(config, "supergat_heads_min", 1)),
                "max": int(getattr(config, "supergat_heads_max", 8)),
            }
        supergat_attention_choices = getattr(config, "supergat_attention_type_choices", None) or [
            "MX",
            "SD",
        ]
        supergat_attention_choices = [
            str(x) for x in supergat_attention_choices if str(x) in {"MX", "SD"}
        ]
        if not supergat_attention_choices:
            supergat_attention_choices = ["MX", "SD"]
        if config.optimize_supergat_attention_type:
            parameters["supergat_attention_type"] = {"values": supergat_attention_choices}

    if config.optimize_weight_decay:
        parameters["weight_decay"] = {
            "distribution": "log_uniform_values",
            "min": float(config.weight_decay_min),
            "max": float(config.weight_decay_max),
        }

    if not parameters:
        raise ValueError(
            "No parameters selected for optimization. "
            "Enable at least one 'Optimize ...' checkbox before running a sweep."
        )

    sweep_config: dict[str, object] = {
        "name": sweep_name or f"dta_gnn_gnn_{arch}",
        "method": "bayes",
        "metric": {"name": "val_score", "goal": "maximize"},
        "parameters": parameters,
    }

    sweep_id = wandb.sweep(sweep=sweep_config, project=project, entity=entity)

    best_score = -math.inf
    best_params: dict[str, object] = {}
    best_trial_number = -1
    trial_counter = {"i": 0}

    from dta_gnn.models.gnn import GnnTrainConfig, train_gnn_on_run

    def _trial_fn():
        nonlocal best_score, best_params, best_trial_number

        run = wandb.init(project=project, entity=entity, config={"architecture": arch})
        trial_idx = int(trial_counter["i"])
        trial_counter["i"] = trial_idx + 1

        logger.info(f"[Trial {trial_idx}] Starting GNN training...")
        logger.info(f"[Trial {trial_idx}] Run ID: {run.id}")

        # Pull optimized params from wandb.config; fill the rest from defaults.
        sampled = dict(getattr(wandb, "config", {}) or {})

        wandb.config.update({"trial_number": trial_idx})

        # Architecture-specific parameters
        gin_conv_mlp_layers = int(
            sampled.get(
                "gin_conv_mlp_layers",
                int(getattr(config, "gin_conv_mlp_layers_default", 2)),
            )
        )
        gin_train_eps = bool(
            sampled.get(
                "gin_train_eps",
                bool(getattr(config, "gin_train_eps_default", False)),
            )
        )
        gin_eps = float(
            sampled.get(
                "gin_eps",
                float(getattr(config, "gin_eps_default", 0.0)),
            )
        )
        gat_heads = int(
            sampled.get(
                "gat_heads",
                int(getattr(config, "gat_heads_default", 4)),
            )
        )
        sage_aggr = str(
            sampled.get(
                "sage_aggr",
                str(getattr(config, "sage_aggr_default", "mean")),
            )
        )
        transformer_heads = int(
            sampled.get(
                "transformer_heads",
                int(getattr(config, "transformer_heads_default", 4)),
            )
        )
        tag_k = int(
            sampled.get(
                "tag_k",
                int(getattr(config, "tag_k_default", 2)),
            )
        )
        arma_num_stacks = int(
            sampled.get(
                "arma_num_stacks",
                int(getattr(config, "arma_num_stacks_default", 1)),
            )
        )
        arma_num_layers = int(
            sampled.get(
                "arma_num_layers",
                int(getattr(config, "arma_num_layers_default", 1)),
            )
        )
        cheb_k = int(
            sampled.get(
                "cheb_k",
                int(getattr(config, "cheb_k_default", 2)),
            )
        )
        supergat_heads = int(
            sampled.get(
                "supergat_heads",
                int(getattr(config, "supergat_heads_default", 4)),
            )
        )
        supergat_attention_type = str(
            sampled.get(
                "supergat_attention_type",
                str(getattr(config, "supergat_attention_type_default", "MX")),
            )
        )

        cfg = GnnTrainConfig(
            architecture=arch,
            epochs=int(sampled.get("epochs", int(getattr(config, "epochs_default", 20)))),
            batch_size=int(sampled.get("batch_size", int(getattr(config, "batch_size_default", 64)))),
            lr=float(sampled.get("lr", 1e-3)),
            weight_decay=float(sampled.get("weight_decay", float(getattr(config, "weight_decay_default", 0.0)))),
            embedding_dim=int(sampled.get("embedding_dim", int(getattr(config, "embedding_dim_default", 128)))),
            device=getattr(config, "device", None),
            hidden_dim=int(sampled.get("hidden_dim", int(getattr(config, "hidden_dim_default", 128)))),
            dropout=float(
                sampled.get(
                    "dropout", float(getattr(config, "dropout_default", 0.1))
                )
            ),
            pooling=str(
                sampled.get(
                    "pooling", str(getattr(config, "pooling_default", "add"))
                )
            ),
            residual=bool(
                sampled.get(
                    "residual", bool(getattr(config, "residual_default", False))
                )
            ),
            head_mlp_layers=int(
                sampled.get(
                    "head_mlp_layers",
                    int(getattr(config, "head_mlp_layers_default", 2)),
                )
            ),
            gin_conv_mlp_layers=gin_conv_mlp_layers,
            gin_train_eps=gin_train_eps,
            gin_eps=gin_eps,
            gat_heads=gat_heads,
            sage_aggr=sage_aggr,
            transformer_heads=transformer_heads,
            tag_k=tag_k,
            arma_num_stacks=arma_num_stacks,
            arma_num_layers=arma_num_layers,
            cheb_k=cheb_k,
            supergat_heads=supergat_heads,
            supergat_attention_type=supergat_attention_type,
            num_layers=int(
                sampled.get(
                    "num_layers",
                    int(getattr(config, "num_layers_default", 3)),
                )
            ),
        )

        trial_dir = run_dir / f"_wandb_{arch}_{trial_idx:03d}_{run.id}"
        trial_dir.mkdir(exist_ok=True)
        shutil.copy2(dataset_path, trial_dir / "dataset.csv")
        shutil.copy2(compounds_path, trial_dir / "compounds.csv")
        if metadata_path.exists():
            shutil.copy2(metadata_path, trial_dir / "metadata.json")

        logger.info(f"[Trial {trial_idx}] Training GNN with config: epochs={cfg.epochs}, batch_size={cfg.batch_size}, lr={cfg.lr:.6f}, embedding_dim={cfg.embedding_dim}")
        res = train_gnn_on_run(trial_dir, config=cfg, wandb_run=run)
        score = _score_from_gnn_metrics(res.task_type, res.metrics)
        logger.info(f"[Trial {trial_idx}] Training complete. Score: {score:.4f}")

        # Log all metrics from all splits
        all_metrics = {}
        splits_metrics = (res.metrics or {}).get("splits", {}) or {}

        # Log validation metrics
        val_metrics = splits_metrics.get("val", {}) or {}
        all_metrics.update({
            "val_score": float(score),
            "val/roc_auc": val_metrics.get("roc_auc"),
            "val/accuracy": val_metrics.get("accuracy"),
            "val/rmse": val_metrics.get("rmse"),
            "val/mae": val_metrics.get("mae"),
            "val/r2": val_metrics.get("r2"),
        })

        # Log training metrics
        train_metrics = splits_metrics.get("train", {}) or {}
        all_metrics.update({
            "train/roc_auc": train_metrics.get("roc_auc"),
            "train/accuracy": train_metrics.get("accuracy"),
            "train/rmse": train_metrics.get("rmse"),
            "train/mae": train_metrics.get("mae"),
            "train/r2": train_metrics.get("r2"),
        })

        # Log test metrics if available
        test_metrics = splits_metrics.get("test", {}) or {}
        if test_metrics:
            all_metrics.update({
                "test/roc_auc": test_metrics.get("roc_auc"),
                "test/accuracy": test_metrics.get("accuracy"),
                "test/rmse": test_metrics.get("rmse"),
                "test/mae": test_metrics.get("mae"),
                "test/r2": test_metrics.get("r2"),
            })

        # Remove None values before logging
        all_metrics = {k: v for k, v in all_metrics.items() if v is not None}
        wandb.log(all_metrics)

        run.summary["val_score"] = float(score)
        run.summary["run_dir"] = str(trial_dir)
        run.summary["architecture"] = arch
        run.summary["task_type"] = res.task_type
        run.summary["trial_number"] = trial_idx
        run.summary["is_best"] = float(score) > float(best_score)
        run.finish()

        if float(score) > float(best_score):
            best_score = float(score)
            best_trial_number = int(trial_idx)
            # Return the *sampled* params plus fixed architecture for reproducibility.
            best_params = {"architecture": arch, **{k: v for k, v in sampled.items()}}
            logger.info(f"[Trial {trial_idx}] New best score: {best_score:.4f} (trial #{best_trial_number})")
        else:
            logger.info(f"[Trial {trial_idx}] Score: {score:.4f} (best so far: {best_score:.4f})")

    logger.info(f"\n{'='*60}")
    logger.info(f"Starting W&B sweep with {config.n_trials} trials")
    logger.info(f"Sweep ID: {sweep_id}")
    logger.info(f"Architecture: {arch}")
    logger.info(f"{'='*60}\n")

    wandb.agent(sweep_id, function=_trial_fn, count=int(config.n_trials))

    best_params_path = str(run_dir / f"hyperopt_best_params_wandb_{arch}.json")
    with open(best_params_path, "w") as f:
        json.dump(best_params, f, indent=2)

    logger.info(f"\n{'='*60}")
    logger.info("Sweep completed!")
    logger.info(f"Best score: {best_score:.4f} (trial #{best_trial_number})")
    logger.info(f"Best params saved to: {best_params_path}")
    logger.info(f"{'='*60}\n")

    return HyperoptResult(
        run_dir=run_dir,
        best_params=best_params,
        best_value=float(best_score) if best_score != -math.inf else 0.0,
        best_trial_number=int(best_trial_number),
        n_trials=int(config.n_trials),
        study_path=str(sweep_id),
        best_params_path=str(best_params_path),
        strategy="holdout-val",
        cv_folds_used=None,
    )

optimize_random_forest and optimize_gnn are aliases that resolve to optimize_random_forest_wandb and optimize_gnn_wandb respectively.

Model utilities

list_available_models

list_available_models(run_dir: Path | None = None) -> dict[str, list[str]]

List all available trained models in the run directory.

Parameters:

Name Type Description Default
run_dir Path | None

Path to run directory. If None, attempts to resolve current run directory.

None

Returns:

Type Description
dict[str, list[str]]

Dictionary with keys: 'rf', 'svr', 'gnn'

dict[str, list[str]]

Each value is a list of model identifiers (for GNN: architecture names)

Source code in src/dta_gnn/models/utils.py
def list_available_models(run_dir: Path | None = None) -> dict[str, list[str]]:
    """List all available trained models in the run directory.

    Args:
        run_dir: Path to run directory. If None, attempts to resolve current run directory.

    Returns:
        Dictionary with keys: 'rf', 'svr', 'gnn'
        Each value is a list of model identifiers (for GNN: architecture names)
    """
    if run_dir is None:
        try:
            run_dir = resolve_current_run_dir()
        except FileNotFoundError:
            return {"rf": [], "svr": [], "gnn": []}

    if not run_dir or not run_dir.exists():
        return {"rf": [], "svr": [], "gnn": []}

    models = {"rf": [], "svr": [], "gnn": []}

    # Check for RandomForest model
    rf_model = run_dir / "model_rf.pkl"
    if rf_model.exists():
        models["rf"].append("RandomForest")

    # Check for SVR model
    svr_model = run_dir / "model_svr.pkl"
    if svr_model.exists():
        models["svr"].append("SVR")

    # Check for GNN models (model_gnn_<architecture>.pt)
    gnn_architectures = [
        "gin",
        "gcn",
        "gat",
        "sage",
        "pna",
        "transformer",
        "tag",
        "arma",
        "cheb",
        "supergat",
    ]
    for arch in gnn_architectures:
        gnn_model = run_dir / f"model_gnn_{arch}.pt"
        config_file = run_dir / f"encoder_{arch}_config.json"

        # Correction in previous write attempt: variable name typo fix
        if gnn_model.exists() and config_file.exists():
            # Format: "GNN (GIN)", "GNN (Transformer)", etc.
            arch_display = (
                arch.upper() if arch in ["gin", "gat", "pna"] else arch.capitalize()
            )
            models["gnn"].append(f"GNN ({arch_display})")

    return models

Audits (dta_gnn.audits)

audit_scaffold_leakage

audit_scaffold_leakage(train_df: DataFrame, test_df: DataFrame, smiles_col: str = 'smiles') -> Dict[str, Any]

Check if scaffolds from test set appear in train set.

Source code in src/dta_gnn/audits/leakage.py
def audit_scaffold_leakage(
    train_df: pd.DataFrame, test_df: pd.DataFrame, smiles_col: str = "smiles"
) -> Dict[str, Any]:
    """
    Count Murcko scaffolds shared between the train and test sets.

    Scaffolds are computed from the non-null values of ``smiles_col``; SMILES
    for which scaffold extraction raises are skipped (logged at debug level).
    The returned dict holds the number of distinct train scaffolds, distinct
    test scaffolds, shared scaffolds, and ``leakage_ratio`` (shared / test
    scaffolds, or 0.0 when the test set yields no scaffold).
    """

    def _unique_scaffolds(frame):
        collected = set()
        for smiles in frame[smiles_col].dropna():
            try:
                collected.add(MurckoScaffold.MurckoScaffoldSmiles(smiles))
            except Exception as e:
                logger.debug("Skipping SMILES {!r} during scaffold extraction: {}", smiles, e)
        return collected

    in_train = _unique_scaffolds(train_df)
    in_test = _unique_scaffolds(test_df)
    shared = in_train & in_test

    ratio = len(shared) / len(in_test) if in_test else 0.0
    return {
        "train_scaffolds": len(in_train),
        "test_scaffolds": len(in_test),
        "overlap_count": len(shared),
        "leakage_ratio": ratio,
    }

audit_target_leakage

audit_target_leakage(train_df: DataFrame, test_df: DataFrame, target_col: str = 'target_chembl_id') -> Dict[str, Any]

Check exact target ID overlap.

Source code in src/dta_gnn/audits/leakage.py
def audit_target_leakage(
    train_df: pd.DataFrame, test_df: pd.DataFrame, target_col: str = "target_chembl_id"
) -> Dict[str, Any]:
    """
    Count target IDs that occur in both the train and test sets.

    Only exact matches of the non-null values in ``target_col`` are compared.
    The returned dict holds the number of distinct train targets, distinct
    test targets, shared targets, and ``leakage_ratio`` (shared / test
    targets, or 0.0 when the test set has no targets).
    """
    in_train = set(train_df[target_col].dropna())
    in_test = set(test_df[target_col].dropna())
    shared = in_train & in_test

    ratio = len(shared) / len(in_test) if in_test else 0.0
    return {
        "train_targets": len(in_train),
        "test_targets": len(in_test),
        "overlap_count": len(shared),
        "leakage_ratio": ratio,
    }

Exporters (dta_gnn.exporters)

collect_artifacts

collect_artifacts(*, run_dir: str | None, dataset_path: str | None = None, targets_path: str | None = None, compounds_path: str | None = None) -> dict[str, str | None]

Collect artifact file paths from a run directory.

Parameters:

Name Type Description Default
run_dir str | None

Path to the run directory

required
dataset_path str | None

Optional explicit path to dataset.csv

None
targets_path str | None

Optional explicit path to targets.csv

None
compounds_path str | None

Optional explicit path to compounds.csv

None

Returns:

Type Description
dict[str, str | None]

Dictionary mapping artifact keys to file paths (or None if not found)

Source code in src/dta_gnn/exporters/artifacts.py
def collect_artifacts(
    *,
    run_dir: str | None,
    dataset_path: str | None = None,
    targets_path: str | None = None,
    compounds_path: str | None = None,
) -> dict[str, str | None]:
    """Collect artifact file paths from a run directory.

    Args:
        run_dir: Path to the run directory
        dataset_path: Optional explicit path to dataset.csv
        targets_path: Optional explicit path to targets.csv
        compounds_path: Optional explicit path to compounds.csv

    Returns:
        Dictionary mapping artifact keys to file paths (or None if not found)
    """
    run_path = Path(run_dir).resolve() if run_dir else None

    def _maybe(p: Path) -> str | None:
        try:
            return str(p) if p.exists() else None
        except Exception as e:
            logger.debug("Path check failed for {!r}: {}", p, e)
            return None

    # If caller didn't provide explicit paths, fall back to conventional names in run_dir.
    if run_path is not None:
        if dataset_path is None:
            dataset_path = _maybe(run_path / "dataset.csv")
        if targets_path is None:
            targets_path = _maybe(run_path / "targets.csv")
        if compounds_path is None:
            compounds_path = _maybe(run_path / "compounds.csv")

        metadata_path = _maybe(run_path / "metadata.json")

        # Model artifacts (RandomForest baseline)
        model_path = _maybe(run_path / "model_rf.pkl")
        model_metrics_path = _maybe(run_path / "model_metrics.json")
        model_predictions_path = _maybe(run_path / "model_predictions.csv")

        # Model artifacts (GNN)
        gnn_model_path = _maybe(run_path / "model_gnn.pt")
        gnn_metrics_path = _maybe(run_path / "model_metrics_gnn.json")
        gnn_predictions_path = _maybe(run_path / "model_predictions_gnn.csv")

        encoder_path = _maybe(run_path / "encoder_gnn.pt")
        encoder_config_path = _maybe(run_path / "encoder_gnn_config.json")
        molecule_embeddings_path = _maybe(run_path / "molecule_embeddings.npz")

        molecule_features_path = _maybe(run_path / "molecule_features.csv")
        protein_features_path = _maybe(run_path / "protein_features.csv")

        zip_path = str(run_path / "artifacts.zip")
    else:
        metadata_path = None
        model_path = None
        model_metrics_path = None
        model_predictions_path = None
        gnn_model_path = None
        gnn_metrics_path = None
        gnn_predictions_path = None
        encoder_path = None
        encoder_config_path = None
        molecule_embeddings_path = None
        molecule_features_path = None
        protein_features_path = None

        zip_path = None

    return {
        "dataset": dataset_path,
        "targets": targets_path,
        "compounds": compounds_path,
        "metadata": metadata_path,
        "model": model_path,
        "model_metrics": model_metrics_path,
        "model_predictions": model_predictions_path,
        "model_gnn": gnn_model_path,
        "model_metrics_gnn": gnn_metrics_path,
        "model_predictions_gnn": gnn_predictions_path,
        "encoder_gnn": encoder_path,
        "encoder_gnn_config": encoder_config_path,
        "molecule_embeddings": molecule_embeddings_path,
        "molecule_features": molecule_features_path,
        "protein_features": protein_features_path,
        "zip": zip_path,
    }

write_artifacts_zip

write_artifacts_zip(*, zip_path: str | None, paths: list[str | None]) -> str | None

Create a zip file from a list of artifact paths.

Parameters:

Name Type Description Default
zip_path str | None

Path where the zip file should be created

required
paths list[str | None]

List of file paths to include in the zip

required

Returns:

Type Description
str | None

Path to the created zip file, or None if creation failed

Source code in src/dta_gnn/exporters/artifacts.py
def write_artifacts_zip(
    *, zip_path: str | None, paths: list[str | None]
) -> str | None:
    """Create a zip file from a list of artifact paths.

    Args:
        zip_path: Path where the zip file should be created
        paths: List of file paths to include in the zip

    Returns:
        Path to the created zip file, or None if creation failed
    """
    if not zip_path:
        return None
    try:
        zpath = Path(zip_path)
        zpath.parent.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zpath, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for p in paths:
                if p and os.path.exists(p):
                    zf.write(p, arcname=os.path.basename(p))
        return str(zpath)
    except Exception as e:
        logger.error("Failed to create artifact zip at {!r}: {}", zip_path, e)
        return None

write_artifacts_zip_from_manifest

write_artifacts_zip_from_manifest(*, artifacts: dict[str, str | None]) -> str | None

Create a zip file from an artifacts manifest dictionary.

Parameters:

Name Type Description Default
artifacts dict[str, str | None]

Dictionary mapping artifact keys to file paths

required

Returns:

Type Description
str | None

Path to the created zip file, or None if creation failed

Source code in src/dta_gnn/exporters/artifacts.py
def write_artifacts_zip_from_manifest(
    *, artifacts: dict[str, str | None]
) -> str | None:
    """Zip the files referenced by an artifacts manifest.

    The archive location is read from the manifest's 'zip' entry and the
    members are the manifest entries named by ``artifact_keys_in_zip()``.

    Args:
        artifacts: Dictionary mapping artifact keys to file paths

    Returns:
        Path to the created zip file, or None if creation failed
    """
    members = list(map(artifacts.get, artifact_keys_in_zip()))
    return write_artifacts_zip(zip_path=artifacts.get("zip"), paths=members)

artifacts_table

artifacts_table(artifacts: dict[str, str | None]) -> pd.DataFrame

Create a DataFrame table of artifacts for UI display.

Parameters:

Name Type Description Default
artifacts dict[str, str | None]

Dictionary mapping artifact keys to file paths

required

Returns:

Type Description
DataFrame

DataFrame with columns 'artifact' and 'path'

Source code in src/dta_gnn/exporters/artifacts.py
def artifacts_table(artifacts: dict[str, str | None]) -> pd.DataFrame:
    """Create a DataFrame table of artifacts for UI display.

    Args:
        artifacts: Dictionary mapping artifact keys to file paths

    Returns:
        DataFrame with columns 'artifact' and 'path'
    """
    return pd.DataFrame(
        [
            {"artifact": "dataset.csv", "path": artifacts.get("dataset") or ""},
            {"artifact": "targets.csv", "path": artifacts.get("targets") or ""},
            {"artifact": "compounds.csv", "path": artifacts.get("compounds") or ""},
            {"artifact": "metadata.json", "path": artifacts.get("metadata") or ""},
            {
                "artifact": "molecule_features.csv",
                "path": artifacts.get("molecule_features") or "",
            },
            {"artifact": "model_rf.pkl", "path": artifacts.get("model") or ""},
            {
                "artifact": "model_metrics.json",
                "path": artifacts.get("model_metrics") or "",
            },
            {
                "artifact": "model_predictions.csv",
                "path": artifacts.get("model_predictions") or "",
            },
            {"artifact": "model_gnn.pt", "path": artifacts.get("model_gnn") or ""},
            {
                "artifact": "model_metrics_gnn.json",
                "path": artifacts.get("model_metrics_gnn") or "",
            },
            {
                "artifact": "model_predictions_gnn.csv",
                "path": artifacts.get("model_predictions_gnn") or "",
            },
            {"artifact": "encoder_gnn.pt", "path": artifacts.get("encoder_gnn") or ""},
            {
                "artifact": "encoder_gnn_config.json",
                "path": artifacts.get("encoder_gnn_config") or "",
            },
            {
                "artifact": "molecule_embeddings.npz",
                "path": artifacts.get("molecule_embeddings") or "",
            },
            {"artifact": "artifacts.zip", "path": artifacts.get("zip") or ""},
        ]
    )

artifact_keys_in_zip

artifact_keys_in_zip() -> list[str]

Stable list of artifact keys included in artifacts.zip.

Keep this centralized so handlers don't duplicate long lists.

Source code in src/dta_gnn/exporters/artifacts.py
def artifact_keys_in_zip() -> list[str]:
    """Return the manifest keys whose files go into artifacts.zip.

    The order is fixed. This is the single source of truth so that handlers
    do not have to repeat the list.
    """
    data_keys = ("dataset", "targets", "compounds", "metadata")
    feature_keys = ("molecule_features", "protein_features")
    baseline_keys = ("model", "model_metrics", "model_predictions")
    gnn_keys = (
        "model_gnn",
        "model_metrics_gnn",
        "model_predictions_gnn",
        "encoder_gnn",
        "encoder_gnn_config",
        "molecule_embeddings",
    )
    return [*data_keys, *feature_keys, *baseline_keys, *gnn_keys]

generate_dataset_card

generate_dataset_card(df: DataFrame, metadata: Dict[str, Any], output_path: str)

Generate a markdown dataset card.

Source code in src/dta_gnn/exporters/card.py
def generate_dataset_card(df: pd.DataFrame, metadata: Dict[str, Any], output_path: str):
    """
    Generate a markdown dataset card.

    The card summarises ``df`` (sample count, label range and mean, column
    names, per-split counts when a ``split`` column exists) together with the
    ``targets``, ``source``, ``split_method`` and optional ``audit`` entries of
    ``metadata``. It is written to ``output_path`` as UTF-8.

    The audit entry is rendered inside a ``json`` code fence. Non-string
    values are serialised with ``json.dumps`` so the fence holds real JSON
    rather than a Python ``repr``; values JSON cannot represent natively are
    stringified.
    """
    import json

    # Calculate label statistics
    labels = df["label"].dropna() if "label" in df.columns else None
    if labels is not None and len(labels) > 0:
        label_range = f"{labels.min():.2f} - {labels.max():.2f} (pChEMBL)"
        mean_affinity = f"{labels.mean():.2f} (pChEMBL)"
    else:
        label_range = 'N/A'
        mean_affinity = 'N/A'

    # Column labels are not guaranteed to be strings (e.g. integer labels).
    column_names = ', '.join(map(str, df.columns))

    card = f"""# Dataset Card

## Metadata
- **Target IDs**: {metadata.get('targets')}
- **Source**: {metadata.get('source')} (Web/SQLite)
- **Date**: {pd.Timestamp.now()}

## Statistics
- **Total Samples**: {len(df)}
- **Label Range**: {label_range}
- **Mean Affinity**: {mean_affinity}
- **Columns**: {column_names}

## Split Information
- **Strategy**: {metadata.get('split_method')}
"""

    if "split" in df.columns:
        counts = df["split"].value_counts().to_dict()
        card += f"""
### Split Counts
- **Train**: {counts.get('train', 0)}
- **Val**: {counts.get('val', 0)}
- **Test**: {counts.get('test', 0)}
"""

    card += """
## Preprocessing
- **Deduplication**: Median aggregation
- **Standardization**: Converted to pChEMBL, dropped invalid units.
- **Audits**: Leakage check performed.

## Leakage Audit
"""
    if "audit" in metadata:
        audit = metadata["audit"]
        # A pre-formatted string is embedded as-is; anything else becomes JSON.
        if isinstance(audit, str):
            audit_text = audit
        else:
            audit_text = json.dumps(audit, indent=2, default=str)
        card += f"```json\n{audit_text}\n```"

    # Explicit encoding: the platform default may not be UTF-8.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(card)

Visualisation (dta_gnn.visualization)

plot_activity_distribution

plot_activity_distribution(df: DataFrame, title: str = 'Activity Distribution') -> plt.Figure

Plot histogram of pChEMBL values.

Source code in src/dta_gnn/visualization.py
def plot_activity_distribution(
    df: pd.DataFrame, title: str = "Activity Distribution"
) -> plt.Figure:
    """
    Plot how many rows fall into each 0.5-wide pChEMBL interval.

    Values of the ``pchembl_value`` column are rounded to the nearest 0.5 and
    drawn as a bar chart of counts per interval. When the column is absent the
    figure only carries a "No pChEMBL values found" note.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    if "pchembl_value" not in df.columns:
        ax.text(0.5, 0.5, "No pChEMBL values found", ha="center")
    else:
        # Round to the nearest half unit: scale by 2, round, scale back.
        half_unit_bins = ((df["pchembl_value"] * 2).round() / 2).rename("pchembl_bin")

        # One row per interval, ordered by interval value.
        counts = half_unit_bins.value_counts().sort_index().reset_index()
        counts.columns = ["pChEMBL Interval", "Count"]

        sns.barplot(
            data=counts,
            x="pChEMBL Interval",
            y="Count",
            ax=ax,
            palette="viridis",
            hue="pChEMBL Interval",
            legend=False,
        )
        ax.set_title(title)
        ax.set_xlabel("pChEMBL Value (Binned 0.5)")
        # Many intervals: tilt the tick labels so they stay readable.
        if len(counts) > 10:
            plt.xticks(rotation=45)

    plt.tight_layout()
    return fig

plot_split_sizes

plot_split_sizes(df: DataFrame) -> plt.Figure

Plot bar chart of split sizes.

Source code in src/dta_gnn/visualization.py
def plot_split_sizes(df: pd.DataFrame) -> plt.Figure:
    """
    Plot the number of rows in each dataset split as a bar chart.

    Counts come from the ``split`` column and each bar is annotated with its
    count. When the column is absent the figure only carries a
    "No split info found" note.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    if "split" not in df.columns:
        ax.text(0.5, 0.5, "No split info found", ha="center")
    else:
        counts = df["split"].value_counts().reset_index()
        counts.columns = ["Split", "Count"]
        sns.barplot(
            data=counts,
            x="Split",
            y="Count",
            hue="Split",
            ax=ax,
            palette="muted",
            legend=False,
        )
        ax.set_title("Dataset Splits")
        # Write each count just above its bar; bars sit at positions 0..n-1.
        for position, size in zip(counts.index, counts["Count"].tolist()):
            ax.text(position, size, str(size), ha="center", va="bottom")

    plt.tight_layout()
    return fig

plot_chemical_space

plot_chemical_space(smiles_data: Union[dict, list], method: str = 't-SNE', radius: int = 2, n_bits: int = 1024, n_components: int = 2, perplexity: int = 30, learning_rate: float = 200.0, random_state: int = 42) -> plt.Figure

Visualize chemical space using Morgan fingerprints and dimensionality reduction. Accepts a dictionary {group_name: [smiles]} or a flat list of SMILES.

Source code in src/dta_gnn/visualization.py
def plot_chemical_space(
    smiles_data: Union[dict, list],
    method: str = "t-SNE",
    radius: int = 2,
    n_bits: int = 1024,
    n_components: int = 2,
    perplexity: int = 30,
    learning_rate: float = 200.0,
    random_state: int = 42,
) -> plt.Figure:
    """
    Project molecules into a low-dimensional map and plot them, coloured by group.

    Each SMILES is turned into a Morgan fingerprint bit vector (``radius``,
    ``n_bits``); the vectors are reduced with ``method`` ("t-SNE" or "PCA") and
    drawn as a scatter plot. Accepts a dictionary {group_name: [smiles]} or a
    flat list of SMILES (plotted as a single group named "Custom"). Entries
    that are not non-empty strings, or that fail to parse or fingerprint, are
    skipped. Raises ValueError for any other ``method``.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from sklearn.manifold import TSNE
    from sklearn.decomposition import PCA
    import numpy as np
    import seaborn as sns

    # Standardize input to dict
    groups = {"Custom": smiles_data} if isinstance(smiles_data, list) else smiles_data

    # Use GetMorganGenerator if available; otherwise use the older bit-vector call.
    try:
        generator = AllChem.GetMorganGenerator(radius=radius, fpSize=n_bits)
    except AttributeError:
        generator = None

    # 1. Generate fingerprints and remember which group each row belongs to.
    bit_rows = []
    group_labels = []
    for group_name, members in groups.items():
        for smi in members:
            if not (isinstance(smi, str) and smi.strip()):
                continue
            try:
                mol = Chem.MolFromSmiles(smi)
                if not mol:
                    continue
                if generator is not None:
                    fp = generator.GetFingerprint(mol)
                else:
                    fp = AllChem.GetMorganFingerprintAsBitVect(
                        mol, radius, nBits=n_bits
                    )
                bit_rows.append([int(ch) for ch in fp.ToBitString()])
                group_labels.append(group_name)
            except Exception:
                continue

    if not bit_rows:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No valid SMILES found.", ha="center")
        return fig

    X = np.array(bit_rows)
    n_samples = X.shape[0]

    # 2. Dimensionality reduction
    if method == "t-SNE":
        # Cap perplexity at n_samples - 1 for small inputs.
        eff_perplexity = min(perplexity, n_samples - 1) if n_samples > 1 else 1
        init_method = "pca" if n_samples > n_components else "random"
        reducer = TSNE(
            n_components=n_components,
            perplexity=eff_perplexity,
            learning_rate=learning_rate,
            random_state=random_state,
            init=init_method,
        )
    elif method == "PCA":
        # Request at most as many components as there are samples.
        reducer = PCA(
            n_components=min(n_components, n_samples), random_state=random_state
        )
    else:
        raise ValueError(f"Unknown method: {method}")
    X_emb = reducer.fit_transform(X)

    # 3. Plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # With a single component, draw the points along a flat line at y = 0.
    y_values = X_emb[:, 1] if X_emb.shape[1] >= 2 else [0] * len(X_emb)
    sns.scatterplot(
        x=X_emb[:, 0],
        y=y_values,
        hue=group_labels,
        ax=ax,
        alpha=0.7,
        palette="tab10",
    )

    ax.set_title(f"Chemical Space ({method})")
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    # Move legend outside if possible
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    return fig