Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions src/bagit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,12 +896,12 @@
if processes == 1:
hash_results = [_calc_hashes(i) for i in args]
else:
pool = multiprocessing.Pool(
processes if processes else None, initializer=worker_init
hash_results = _multiprocessing_pool_map(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering whether we might want to refactor this to avoid buffering everything in memory (i.e. using Pool.imap so that the utility function could yield results as they become available), but I don't think it's very common that people are using this in a way where they need to worry about total RAM or can meaningfully process the results before they're all ready.

_calc_hashes,
args,
processes if processes else None,
initializer=worker_init,
)
hash_results = pool.map(_calc_hashes, args)
pool.close()
pool.join()

# Any unhandled exceptions are probably fatal
except:
Expand Down Expand Up @@ -1037,6 +1037,25 @@
signal.signal(signal.SIGINT, signal.SIG_IGN)


def _multiprocessing_pool_map(func, iterable, processes, initializer=None):
    """Map ``func`` over ``iterable`` in a worker pool, then dispose of the pool.

    A ``multiprocessing.Pool`` is created with the given ``processes`` and
    ``initializer``. If ``Pool.map()`` raises anything (including
    ``KeyboardInterrupt``/``SystemExit``), the pool is terminated and the
    exception is re-raised; otherwise the pool is closed and the mapped
    results are returned. In both cases the pool is joined before this
    function exits.
    """
    worker_pool = multiprocessing.Pool(processes=processes, initializer=initializer)
    try:
        try:
            mapped = worker_pool.map(func, iterable)
        except BaseException:
            # Failure path: stop the workers immediately, then propagate.
            worker_pool.terminate()
            raise
        # Success path: no more work will be submitted.
        worker_pool.close()
        return mapped
    finally:
        # Runs after either terminate() or close(), so join() is valid here.
        worker_pool.join()


# The Unicode normalization form used here doesn't matter – all we care about
# is consistency since the input value will be preserved:

Expand Down Expand Up @@ -1245,10 +1264,9 @@
manifest_line_generator = partial(generate_manifest_lines, algorithms=algorithms)

if processes > 1:
pool = multiprocessing.Pool(processes=processes)
checksums = pool.map(manifest_line_generator, _walk(data_dir))
pool.close()
pool.join()
checksums = _multiprocessing_pool_map(
manifest_line_generator, _walk(data_dir), processes=processes
)
else:
checksums = [manifest_line_generator(i) for i in _walk(data_dir)]

Expand Down Expand Up @@ -1583,7 +1601,7 @@
else:
LOGGER.info(_("%s is valid"), bag_dir)
except BagError as e:
LOGGER.error(

Check failure on line 1604 in src/bagit/__init__.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "logging.exception()" instead.

See more on https://sonarcloud.io/project/issues?id=LibraryOfCongress_bagit-python&issues=AZ7WfiStvaLYFV3NH828&open=AZ7WfiStvaLYFV3NH828&pullRequest=208
_("%(bag)s is invalid: %(error)s"), {"bag": bag_dir, "error": e}
)
rc = 1
Expand All @@ -1598,7 +1616,7 @@
checksums=args.checksums,
)
except Exception as exc:
LOGGER.error(

Check failure on line 1619 in src/bagit/__init__.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "logging.exception()" instead.

See more on https://sonarcloud.io/project/issues?id=LibraryOfCongress_bagit-python&issues=AZ7WfiStvaLYFV3NH829&open=AZ7WfiStvaLYFV3NH829&pullRequest=208
_("Failed to create bag in %(bag_directory)s: %(error)s"),
{"bag_directory": bag_dir, "error": exc},
exc_info=True,
Expand Down
35 changes: 33 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import tempfile
import unicodedata
import unittest
from io import StringIO
from os.path import join as j

from unittest import mock
from io import StringIO

import bagit

Expand Down Expand Up @@ -458,6 +457,23 @@ def validate(self, bag, *args, **kwargs):
bag, *args, processes=2, **kwargs
)

@mock.patch("bagit.multiprocessing.Pool")
def test_validate_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
    # The patched Pool's map() blows up; validation must propagate the error
    # and the pool instance must see exactly map -> terminate -> join.
    fake_pool = pool.return_value
    fake_pool.map.side_effect = RuntimeError("boom")
    bag = bagit.make_bag(self.tmpdir)

    self.assertRaises(RuntimeError, self.validate, bag)

    expected_calls = [
        mock.call.map(mock.ANY, mock.ANY),
        mock.call.terminate(),
        mock.call.join(),
    ]
    self.assertEqual(fake_pool.mock_calls, expected_calls)

@mock.patch("bagit.multiprocessing.Pool")
def test_validate_pool_error(self, pool):
# Simulate the Pool constructor raising a RuntimeError.
Expand Down Expand Up @@ -745,6 +761,21 @@ def test_make_bag_multiprocessing(self):
bagit.make_bag(self.tmpdir, processes=2)
self.assertTrue(os.path.isdir(j(self.tmpdir, "data")))

@mock.patch("bagit.multiprocessing.Pool")
def test_make_bag_multiprocessing_terminates_and_joins_pool_on_failure(self, pool):
    # The patched Pool's map() blows up; make_bag must propagate the error
    # and the pool instance must see exactly map -> terminate -> join.
    fake_pool = pool.return_value
    fake_pool.map.side_effect = RuntimeError("boom")

    self.assertRaises(RuntimeError, bagit.make_bag, self.tmpdir, processes=2)

    expected_calls = [
        mock.call.map(mock.ANY, mock.ANY),
        mock.call.terminate(),
        mock.call.join(),
    ]
    self.assertEqual(fake_pool.mock_calls, expected_calls)

def test_multiple_meta_values(self):
baginfo = {"Multival-Meta": [7, 4, 8, 6, 8]}
bag = bagit.make_bag(self.tmpdir, baginfo)
Expand Down
Loading