102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
"""
|
|
Parametrize a test function by CSV file
|
|
"""
|
|
import csv
|
|
from pathlib import Path
|
|
from typing import List, Optional, TypedDict
|
|
|
|
import pytest
|
|
|
|
from pytest_csv_params.dialect import CsvParamsDefaultDialect
|
|
from pytest_csv_params.exception import (
|
|
CsvParamsDataFileInaccessible,
|
|
CsvParamsDataFileInvalid,
|
|
CsvParamsDataFileNotFound,
|
|
)
|
|
from pytest_csv_params.types import BaseDir, CsvDialect, DataCastDict, DataFile, IdColName
|
|
|
|
|
|
class TestCaseParameters(TypedDict):
|
|
"""
|
|
Type for Test Case
|
|
"""
|
|
|
|
test_id: Optional[str]
|
|
data: List
|
|
|
|
|
|
def read_csv(base_dir: BaseDir, data_file: DataFile, dialect: CsvDialect):
|
|
"""
|
|
Get Data from CSV
|
|
"""
|
|
|
|
if data_file is None:
|
|
raise CsvParamsDataFileInvalid("Data file is None") from None
|
|
csv_file = Path(data_file)
|
|
if not csv_file.is_absolute():
|
|
if base_dir is not None:
|
|
csv_file = Path(base_dir) / csv_file
|
|
if not csv_file.exists() or not csv_file.is_file():
|
|
raise CsvParamsDataFileNotFound(f"Cannot find file: {str(csv_file)}") from None
|
|
csv_lines = []
|
|
try:
|
|
with open(csv_file, newline="", encoding="utf-8") as csv_file_handle:
|
|
csv_reader = csv.reader(csv_file_handle, dialect=dialect)
|
|
for row in csv_reader:
|
|
csv_lines.append(row)
|
|
except IOError as err: # pragma: no cover
|
|
raise CsvParamsDataFileInaccessible(f"Unable to read file: {str(csv_file)}") from err
|
|
except csv.Error as err:
|
|
raise CsvParamsDataFileInvalid("Invalid data") from err
|
|
return csv_lines
|
|
|
|
|
|
def add_parametrization(
|
|
base_dir: BaseDir = None,
|
|
data_file: DataFile = None,
|
|
id_col: IdColName = None,
|
|
data_casts: DataCastDict = None,
|
|
dialect: CsvDialect = CsvParamsDefaultDialect,
|
|
):
|
|
"""
|
|
Get data from the files and add things to the tests
|
|
"""
|
|
csv_lines = read_csv(base_dir, data_file, dialect)
|
|
if len(csv_lines) < 2:
|
|
raise CsvParamsDataFileInvalid("File does not contain a single data row") from None
|
|
id_index = -1
|
|
headers = csv_lines.pop(0)
|
|
if id_col is not None:
|
|
try:
|
|
id_index = headers.index(id_col)
|
|
del headers[id_index]
|
|
except ValueError as err:
|
|
raise CsvParamsDataFileInvalid(f"Cannot find ID column '{id_col}'") from err
|
|
if len(headers) == 0:
|
|
raise CsvParamsDataFileInvalid("File seems only to have IDs") from None
|
|
data: List[TestCaseParameters] = []
|
|
for data_line in csv_lines:
|
|
line = list(map(str, data_line))
|
|
if len(line) == 0:
|
|
continue
|
|
test_id = None
|
|
if id_index >= 0:
|
|
test_id = line.pop(id_index)
|
|
if len(headers) != len(line):
|
|
raise CsvParamsDataFileInvalid("Header and Data length mismatch") from None
|
|
if data_casts is not None:
|
|
for index, header in enumerate(headers):
|
|
caster = data_casts.get(header, None)
|
|
if caster is not None:
|
|
line[index] = caster(str(line[index]))
|
|
data.append(
|
|
{
|
|
"test_id": test_id,
|
|
"data": line,
|
|
}
|
|
)
|
|
id_data = None
|
|
if id_col is not None:
|
|
id_data = [tc_data["test_id"] for tc_data in data]
|
|
return pytest.mark.parametrize(headers, [tc_data["data"] for tc_data in data], ids=id_data)
|