Source code for pact_testgen.files
import json
from ast import parse
from _ast import Module, FunctionDef
from io import StringIO
from typing import Generator, List, Tuple
from pathlib import Path
from enum import Enum, auto
try:
from ast import unparse
except ImportError:
unparse = None
from pact_testgen.models import Pact
[docs]class ProviderStateFileOutcome(Enum):
WROTE_NEW = auto()
MERGED = auto()
LEFT_EXISTING = auto()
NO_CHANGES_REQUIRED = auto()
[docs]def load_pact_file(path: str) -> Pact:
"""Loads the file at the supplied path into a Pact model"""
with open(Path(path), "r") as f:
pact = json.load(f)
return Pact(**pact)
[docs]def write_test_file(testfile: str, path: Path):
with open(path, "w") as f:
f.write(testfile)
[docs]def write_provider_state_file(
provider_state_file: str, path: Path, merge_file=False
) -> ProviderStateFileOutcome:
# TODO: Support appending new provider state functions.
# For now, don't write the file if it already exists
exists = path.exists()
if exists:
if merge_file:
with open(path, "r+") as target_handle:
target = target_handle.read()
final, num_added_functions = merge(target, provider_state_file)
target_handle.seek(0)
target_handle.write(final)
if num_added_functions:
return ProviderStateFileOutcome.MERGED
return ProviderStateFileOutcome.NO_CHANGES_REQUIRED
return ProviderStateFileOutcome.LEFT_EXISTING
with open(path, "w") as f:
f.write(provider_state_file)
return ProviderStateFileOutcome.WROTE_NEW
[docs]def get_functions(mod: Module) -> Generator[FunctionDef, None, None]:
for node in mod.body:
if isinstance(node, FunctionDef):
yield node
[docs]def merge_is_available():
return unparse is not None
[docs]def merge(target: str, src: str) -> Tuple[str, int]:
"""
Merge "src" code into "target".
Only add functions from src that aren't
already present in target.
Returns the merged file as a string, and the number
of functions that were added.
"""
if not merge_is_available:
raise RuntimeError("Cannot merge. No unparse function available")
target_ast = parse(target)
target_buffer = StringIO()
target_buffer.write(target)
src_ast = parse(src)
existing_function_names = set([func.name for func in get_functions(target_ast)])
function_bodies_to_add: List[str] = []
for funcdef in get_functions(src_ast):
if funcdef.name not in existing_function_names:
function_bodies_to_add.append(unparse(funcdef))
if function_bodies_to_add:
target_buffer.write("\n")
target_buffer.write("\n\n".join(function_bodies_to_add))
target_buffer.write("\n\n")
return target_buffer.getvalue(), len(function_bodies_to_add)