writefile should write to a temporary file first

This commit is contained in:
Jan Petykiewicz 2023-01-26 13:54:00 -08:00
parent 6cbdd7930d
commit 6172abf77c
4 changed files with 77 additions and 52 deletions

View File

@ -6,7 +6,8 @@ Notes:
* ezdxf sets creation time, write time, $VERSIONGUID, and $FINGERPRINTGUID * ezdxf sets creation time, write time, $VERSIONGUID, and $FINGERPRINTGUID
to unique values, so byte-for-byte reproducibility is not achievable for now to unique values, so byte-for-byte reproducibility is not achievable for now
""" """
from typing import List, Any, Dict, Tuple, Callable, Union, Mapping, TextIO from typing import List, Any, Dict, Tuple, Callable, Union, Mapping
from typing import cast, TextIO, IO
import io import io
import logging import logging
import pathlib import pathlib
@ -15,12 +16,12 @@ import gzip
import numpy import numpy
import ezdxf import ezdxf
from .utils import is_gzipped, tmpfile
from .. import Pattern, Ref, PatternError, Label from .. import Pattern, Ref, PatternError, Label
from ..library import Library, WrapROLibrary from ..library import Library, WrapROLibrary
from ..shapes import Shape, Polygon, Path from ..shapes import Shape, Polygon, Path
from ..repetition import Grid from ..repetition import Grid
from ..utils import rotation_matrix_2d, layer_t from ..utils import rotation_matrix_2d, layer_t
from .utils import is_gzipped
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -125,22 +126,22 @@ def writefile(
""" """
path = pathlib.Path(filename) path = pathlib.Path(filename)
streams: Tuple[Any, ...] gz_stream: IO[bytes]
stream: TextIO with tmpfile(path) as base_stream:
if path.suffix == '.gz': streams: Tuple[Any, ...] = (base_stream,)
base_stream = open(path, mode='wb') if path.suffix == '.gz':
gz_stream = gzip.GzipFile(filename='', mtime=0, fileobj=base_stream) gz_stream = cast(IO[bytes], gzip.GzipFile(filename='', mtime=0, fileobj=base_stream))
streams = (gz_stream,) + streams
else:
gz_stream = base_stream
stream = io.TextIOWrapper(gz_stream) # type: ignore stream = io.TextIOWrapper(gz_stream) # type: ignore
streams = (stream, gz_stream, base_stream) streams = (stream,) + streams
else:
stream = open(path, mode='wt')
streams = (stream,)
try: try:
write(library, top_name, stream, *args, **kwargs) write(library, top_name, stream, *args, **kwargs)
finally: finally:
for ss in streams: for ss in streams:
ss.close() ss.close()
def readfile( def readfile(

View File

@ -20,7 +20,7 @@ Notes:
* Gzip modification time is set to 0 (start of current epoch, usually 1970-01-01) * Gzip modification time is set to 0 (start of current epoch, usually 1970-01-01)
""" """
from typing import List, Dict, Tuple, Callable, Union, Iterable, Mapping from typing import List, Dict, Tuple, Callable, Union, Iterable, Mapping
from typing import BinaryIO, cast, Optional, Any from typing import IO, cast, Optional, Any
import io import io
import mmap import mmap
import logging import logging
@ -34,7 +34,7 @@ from numpy.typing import ArrayLike, NDArray
import klamath import klamath
from klamath import records from klamath import records
from .utils import is_gzipped from .utils import is_gzipped, tmpfile
from .. import Pattern, Ref, PatternError, LibraryError, Label, Shape from .. import Pattern, Ref, PatternError, LibraryError, Label, Shape
from ..shapes import Polygon, Path from ..shapes import Polygon, Path
from ..repetition import Grid from ..repetition import Grid
@ -59,7 +59,7 @@ def rint_cast(val: ArrayLike) -> NDArray[numpy.int32]:
def write( def write(
library: Mapping[str, Pattern], library: Mapping[str, Pattern],
stream: BinaryIO, stream: IO[bytes],
meters_per_unit: float, meters_per_unit: float,
logical_units_per_unit: float = 1, logical_units_per_unit: float = 1,
library_name: str = 'masque-klamath', library_name: str = 'masque-klamath',
@ -142,19 +142,19 @@ def writefile(
""" """
path = pathlib.Path(filename) path = pathlib.Path(filename)
base_stream = open(path, mode='wb') with tmpfile(path) as base_stream:
streams: Tuple[Any, ...] = (base_stream,) streams: Tuple[Any, ...] = (base_stream,)
if path.suffix == '.gz': if path.suffix == '.gz':
stream = cast(BinaryIO, gzip.GzipFile(filename='', mtime=0, fileobj=base_stream)) stream = cast(IO[bytes], gzip.GzipFile(filename='', mtime=0, fileobj=base_stream))
streams = (stream,) + streams streams = (stream,) + streams
else: else:
stream = base_stream stream = base_stream
try: try:
write(library, stream, *args, **kwargs) write(library, stream, *args, **kwargs)
finally: finally:
for ss in streams: for ss in streams:
ss.close() ss.close()
def readfile( def readfile(
@ -184,7 +184,7 @@ def readfile(
def read( def read(
stream: BinaryIO, stream: IO[bytes],
raw_mode: bool = True, raw_mode: bool = True,
) -> Tuple[Dict[str, Pattern], Dict[str, Any]]: ) -> Tuple[Dict[str, Pattern], Dict[str, Any]]:
""" """
@ -220,7 +220,7 @@ def read(
return patterns_dict, library_info return patterns_dict, library_info
def _read_header(stream: BinaryIO) -> Dict[str, Any]: def _read_header(stream: IO[bytes]) -> Dict[str, Any]:
""" """
Read the file header and create the library_info dict. Read the file header and create the library_info dict.
""" """
@ -234,7 +234,7 @@ def _read_header(stream: BinaryIO) -> Dict[str, Any]:
def read_elements( def read_elements(
stream: BinaryIO, stream: IO[bytes],
raw_mode: bool = True, raw_mode: bool = True,
) -> Pattern: ) -> Pattern:
""" """
@ -509,7 +509,7 @@ def _labels_to_texts(labels: List[Label]) -> List[klamath.elements.Text]:
def load_library( def load_library(
stream: BinaryIO, stream: IO[bytes],
*, *,
full_load: bool = False, full_load: bool = False,
postprocess: Optional[Callable[[Library, str, Pattern], Pattern]] = None postprocess: Optional[Callable[[Library, str, Pattern], Pattern]] = None
@ -595,7 +595,7 @@ def load_libraryfile(
Additional library info (dict, same format as from `read`). Additional library info (dict, same format as from `read`).
""" """
path = pathlib.Path(filename) path = pathlib.Path(filename)
stream: BinaryIO stream: IO[bytes]
if is_gzipped(path): if is_gzipped(path):
if mmap: if mmap:
logger.info('Asked to mmap a gzipped file, reading into memory instead...') logger.info('Asked to mmap a gzipped file, reading into memory instead...')

View File

@ -15,7 +15,7 @@ Notes:
* Gzip modification time is set to 0 (start of current epoch, usually 1970-01-01) * Gzip modification time is set to 0 (start of current epoch, usually 1970-01-01)
""" """
from typing import List, Any, Dict, Tuple, Callable, Union, Iterable from typing import List, Any, Dict, Tuple, Callable, Union, Iterable
from typing import BinaryIO, Mapping, Optional, cast, Sequence from typing import IO, Mapping, Optional, cast, Sequence
import logging import logging
import pathlib import pathlib
import gzip import gzip
@ -28,7 +28,7 @@ import fatamorgana
import fatamorgana.records as fatrec import fatamorgana.records as fatrec
from fatamorgana.basic import PathExtensionScheme, AString, NString, PropStringReference from fatamorgana.basic import PathExtensionScheme, AString, NString, PropStringReference
from .utils import is_gzipped from .utils import is_gzipped, tmpfile
from .. import Pattern, Ref, PatternError, LibraryError, Label, Shape from .. import Pattern, Ref, PatternError, LibraryError, Label, Shape
from ..library import WrapLibrary, MutableLibrary from ..library import WrapLibrary, MutableLibrary
from ..shapes import Polygon, Path, Circle from ..shapes import Polygon, Path, Circle
@ -150,7 +150,7 @@ def build(
def write( def write(
library: Mapping[str, Pattern], # NOTE: Pattern here should be treated as immutable! library: Mapping[str, Pattern], # NOTE: Pattern here should be treated as immutable!
stream: BinaryIO, stream: IO[bytes],
*args, *args,
**kwargs, **kwargs,
) -> None: ) -> None:
@ -187,19 +187,19 @@ def writefile(
""" """
path = pathlib.Path(filename) path = pathlib.Path(filename)
base_stream = open(path, mode='wb') with tmpfile(path) as base_stream:
streams: Tuple[Any, ...] = (base_stream,) streams: Tuple[Any, ...] = (base_stream,)
if path.suffix == '.gz': if path.suffix == '.gz':
stream = cast(BinaryIO, gzip.GzipFile(filename='', mtime=0, fileobj=base_stream)) stream = cast(IO[bytes], gzip.GzipFile(filename='', mtime=0, fileobj=base_stream))
streams += (stream,) streams += (stream,)
else: else:
stream = base_stream stream = base_stream
try: try:
write(library, stream, *args, **kwargs) write(library, stream, *args, **kwargs)
finally: finally:
for ss in streams: for ss in streams:
ss.close() ss.close()
def readfile( def readfile(
@ -229,7 +229,7 @@ def readfile(
def read( def read(
stream: BinaryIO, stream: IO[bytes],
) -> Tuple[Dict[str, Pattern], Dict[str, Any]]: ) -> Tuple[Dict[str, Pattern], Dict[str, Any]]:
""" """
Read a OASIS file and translate it into a dict of Pattern objects. OASIS cells are Read a OASIS file and translate it into a dict of Pattern objects. OASIS cells are

View File

@ -1,9 +1,13 @@
""" """
Helper functions for file reading and writing Helper functions for file reading and writing
""" """
from typing import Union, IO, Iterator
import re import re
import pathlib import pathlib
import logging import logging
import tempfile
import shutil
from contextlib import contextmanager
from .. import Pattern, PatternError from .. import Pattern, PatternError
from ..shapes import Polygon, Path from ..shapes import Polygon, Path
@ -55,3 +59,23 @@ def is_gzipped(path: pathlib.Path) -> bool:
with open(path, 'rb') as stream: with open(path, 'rb') as stream:
magic_bytes = stream.read(2) magic_bytes = stream.read(2)
return magic_bytes == b'\x1f\x8b' return magic_bytes == b'\x1f\x8b'
@contextmanager
def tmpfile(path: Union[str, pathlib.Path]) -> Iterator[IO[bytes]]:
    """
    Context manager which allows you to write to a temporary file,
    and move that file into its final location only after the write
    has finished.

    If the body of the `with` block (or the final move) raises, the
    temporary file is removed and the exception propagates to the caller.

    Args:
        path: Final destination of the file. Its suffixes (e.g. `.gds.gz`)
            are reused as the suffix of the temporary file's name.

    Yields:
        A binary stream, open for writing, backed by the temporary file.
    """
    path = pathlib.Path(path)
    suffixes = ''.join(path.suffixes)

    # delete=False: the file must outlive the stream so that it can be moved
    # into place after it has been closed.
    tmp_stream = tempfile.NamedTemporaryFile(suffix=suffixes, delete=False)
    tmp_path = pathlib.Path(tmp_stream.name)
    try:
        # Close the stream before moving; an exception raised by the caller's
        # `with` body is re-raised here at the `yield`, skipping the move.
        with tmp_stream:
            yield tmp_stream
        shutil.move(tmp_stream.name, path)
    finally:
        # After a successful rename this is a no-op; after a failure (in the
        # caller's body or during the move) it prevents leaking the temp file.
        tmp_path.unlink(missing_ok=True)