Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds input of NotebookNode object directly #670

Merged
merged 15 commits into from
Aug 12, 2022
4 changes: 2 additions & 2 deletions papermill/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def execute_notebook(

Parameters
----------
input_path : str or Path
Path to input notebook
input_path : str or Path or nbformat.NotebookNode
Path to input notebook or NotebookNode object of notebook
output_path : str or Path or None
Path to save executed notebook. If None, no file will be saved
parameters : dict, optional
Expand Down
112 changes: 73 additions & 39 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,50 +96,14 @@ def __init__(self):
self.reset()

def read(self, path, extensions=['.ipynb', '.json']):
if path == '-':
return sys.stdin.read()

if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions
):
warnings.warn(
"The specified input file ({}) does not end in one of {}".format(path, extensions)
)
# Handle https://github.com/nteract/papermill/issues/317
notebook_metadata = self.get_handler(path).read(path)
notebook_metadata = self.get_handler(path, extensions).read(path)
if isinstance(notebook_metadata, (bytes, bytearray)):
return notebook_metadata.decode('utf-8')
return notebook_metadata

def write(self, buf, path, extensions=['.ipynb', '.json']):
if path is None:
return
if path == '-':
try:
return sys.stdout.buffer.write(buf.encode('utf-8'))
except AttributeError:
# Originally required by https://github.com/nteract/papermill/issues/420
# Support Buffer.io objects
return sys.stdout.write(buf.encode('utf-8'))

return sys.stdout.buffer.write(buf.encode('utf-8'))

# Usually no return object here
if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions
):
warnings.warn(
"The specified output file ({}) does not end in one of {}".format(path, extensions)
)
return self.get_handler(path).write(buf, path)
return self.get_handler(path, extensions).write(buf, path)

def listdir(self, path):
return self.get_handler(path).listdir(path)
Expand All @@ -159,10 +123,44 @@ def register_entry_points(self):
for entrypoint in entrypoints.get_group_all("papermill.io"):
self.register(entrypoint.name, entrypoint.load())

def get_handler(self, path):
def get_handler(self, path, extensions=None):
'''Get I/O Handler based on a notebook path

Parameters
----------
path : str or nbformat.NotebookNode or None
extensions : list of str, optional
Required file extension options for the path (if path is a string), which
will log a warning if there is no match. Defaults to None, which does not
check for any extensions

Raises
------
PapermillException: If a valid I/O handler could not be found for the input path

Returns
-------
I/O Handler
'''
if path is None:
return NoIOHandler()

if isinstance(path, nbformat.NotebookNode):
return NotebookNodeHandler()

if extensions:
if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext)
for ext in extensions
):
warnings.warn(
"The specified file ({}) does not end in one of {}".format(path, extensions)
)

local_handler = None
for scheme, handler in self._handlers:
if scheme == 'local':
Expand Down Expand Up @@ -411,6 +409,41 @@ def pretty_path(self, path):
return path


class StreamHandler(object):
    '''Handler for Stdin/Stdout streams'''

    def read(self, path):
        '''Return everything readable from ``sys.stdin``.

        ``path`` is accepted for interface compatibility with the other
        handlers and is not used.
        '''
        return sys.stdin.read()

    def listdir(self, path):
        '''Always raise PapermillException: a stream has no directory listing.'''
        raise PapermillException('listdir is not supported by Stream Handler')

    def write(self, buf, path):
        '''Write the string ``buf`` to ``sys.stdout``.

        ``path`` is accepted for interface compatibility and is not used.

        The UTF-8 encoded bytes are written to ``sys.stdout.buffer`` when that
        attribute exists. Otherwise the bytes are written to ``sys.stdout``
        itself, and if that stream only accepts text (it raises TypeError on
        bytes, as ``io.StringIO`` does) the unencoded string is written instead.

        Returns whatever the underlying ``write`` call returns.
        '''
        payload = buf.encode('utf-8')

        # Only the attribute lookup is probed, so an AttributeError raised
        # inside the actual write is not mistaken for "no binary buffer".
        binary_stream = getattr(sys.stdout, 'buffer', None)
        if binary_stream is not None:
            return binary_stream.write(payload)

        # Originally required by https://github.com/nteract/papermill/issues/420
        # Support Buffer.io objects
        try:
            return sys.stdout.write(payload)
        except TypeError:
            # Text-only replacement stream without a ``buffer``: hand it the
            # text rather than failing on the encoded bytes.
            return sys.stdout.write(buf)

    def pretty_path(self, path):
        '''Return ``path`` unchanged.'''
        return path


class NotebookNodeHandler(object):
    '''Handler used when the input "path" is an nbformat.NotebookNode object.

    Only reading is supported; ``listdir`` and ``write`` always raise
    PapermillException.
    '''

    def pretty_path(self, path):
        '''Return a fixed label; the ``path`` argument is not inspected.'''
        label = 'NotebookNode object'
        return label

    def read(self, path):
        '''Return the result of ``nbformat.writes`` for the node given as ``path``.'''
        serialized = nbformat.writes(path)
        return serialized

    def write(self, buf, path):
        '''Always raise: nothing can be written back into a NotebookNode input.'''
        raise PapermillException('write is not supported by NotebookNode Handler')

    def listdir(self, path):
        '''Always raise: a NotebookNode has no directory listing.'''
        raise PapermillException('listdir is not supported by NotebookNode Handler')


class NoIOHandler(object):
'''Handler for output_path of None - intended to not write anything'''

Expand Down Expand Up @@ -448,6 +481,7 @@ class NoDatesSafeLoader(yaml.SafeLoader):
papermill_io.register("hdfs://", HDFSHandler())
papermill_io.register("http://github.com/", GithubHandler())
papermill_io.register("https://github.com/", GithubHandler())
papermill_io.register("-", StreamHandler())
papermill_io.register_entry_points()


Expand Down
9 changes: 5 additions & 4 deletions papermill/parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def parameterize_path(path, parameters):

Parameters
----------
path : string or None
Path with optional parameters, as a python format string
path : string or nbformat.NotebookNode or None
Path with optional parameters, as a python format string. If path is a NotebookNode
or None, the path is returned without modification
parameters : dict or None
Arbitrary keyword arguments to fill in the path
"""
if path is None:
return
if path is None or isinstance(path, nbformat.NotebookNode):
return path

if parameters is None:
parameters = {}
Expand Down
20 changes: 20 additions & 0 deletions papermill/tests/notebooks/test_notebooknode_io.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": ["print('Hello World')"]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
16 changes: 16 additions & 0 deletions papermill/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from pathlib import Path

import nbformat
from nbformat import validate

from .. import engines, translators
Expand Down Expand Up @@ -436,3 +437,18 @@ def test_custom_kernel_name_and_language(self, translate_parameters, execute_man
)
self.assertEqual(execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel"))
self.assertEqual(translate_parameters.call_args[0], (ANY, 'my_custom_language', {"msg": "fake msg"}, ANY))


class TestNotebookNodeInput(unittest.TestCase):
    '''Checks that execute_notebook accepts a NotebookNode as its input.'''

    def setUp(self):
        # Each test gets its own scratch directory for the executed notebook.
        self.test_dir = tempfile.TemporaryDirectory()
        self.result_path = os.path.join(self.test_dir.name, 'output.ipynb')

    def tearDown(self):
        self.test_dir.cleanup()

    def test_notebook_node_input(self):
        source_path = get_notebook_path('simple_execute.ipynb')
        input_nb = nbformat.read(source_path, as_version=4)

        execute_notebook(input_nb, self.result_path, {'msg': 'Hello'})

        # The parameters are read back from the notebook written to result_path.
        executed_nb = nbformat.read(self.result_path, as_version=4)
        recorded = executed_nb.metadata.papermill.parameters
        self.assertEqual(recorded, {'msg': 'Hello'})
63 changes: 61 additions & 2 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
LocalHandler,
NoIOHandler,
ADLHandler,
NotebookNodeHandler,
StreamHandler,
PapermillIO,
read_yaml_file,
papermill_io,
local_file_io_cwd,
)
from ..exceptions import PapermillException
from . import get_notebook_path

FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')

Expand Down Expand Up @@ -94,6 +97,10 @@ def test_get_local_handler(self):
def test_get_no_io_handler(self):
    # A path of None must resolve to a NoIOHandler instance.
    handler = self.papermill_io.get_handler(None)
    self.assertIsInstance(handler, NoIOHandler)

def test_get_notebook_node_handler(self):
    # Passing a NotebookNode in place of a path must select NotebookNodeHandler.
    fixture_path = get_notebook_path('test_notebooknode_io.ipynb')
    test_nb = nbformat.read(fixture_path, as_version=4)
    handler = self.papermill_io.get_handler(test_nb)
    self.assertIsInstance(handler, NotebookNodeHandler)

def test_entrypoint_register(self):

fake_entrypoint = Mock(load=Mock())
Expand Down Expand Up @@ -155,7 +162,7 @@ def test_read_yaml_with_invalid_file_extension(self):
def test_read_stdin(self):
file_content = u'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
with patch('sys.stdin', io.StringIO(file_content)):
self.assertEqual(self.papermill_io.read("-"), file_content)
self.assertEqual(self.old_papermill_io.read("-"), file_content)

def test_listdir(self):
self.assertEqual(self.papermill_io.listdir("fake/path"), ["fake", "contents"])
Expand All @@ -178,7 +185,7 @@ def test_write_stdout(self):
file_content = u'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
out = io.BytesIO()
with patch('sys.stdout', out):
self.papermill_io.write(file_content, "-")
self.old_papermill_io.write(file_content, "-")
self.assertEqual(out.getvalue(), file_content.encode('utf-8'))

def test_pretty_path(self):
Expand Down Expand Up @@ -343,3 +350,55 @@ def test_write_failure(self):

with self.assertRaises(ConnectionError):
HttpHandler.write(buf, path)


class TestStreamHandler(unittest.TestCase):
    '''Unit tests for StreamHandler reading stdin and writing stdout.'''

    @patch('sys.stdin', io.StringIO('mock stream'))
    def test_read_from_stdin(self):
        handler = StreamHandler()
        self.assertEqual(handler.read('foo'), 'mock stream')

    def test_raises_on_listdir(self):
        with self.assertRaises(PapermillException):
            StreamHandler().listdir(None)

    @patch('sys.stdout')
    def test_write_to_stdout_buffer(self, mock_stdout):
        # The patched stdout exposes a ``buffer`` attribute holding a byte stream.
        mock_stdout.buffer = io.BytesIO()
        StreamHandler().write('mock stream', 'foo')
        expected = 'mock stream'.encode('utf-8')
        self.assertEqual(mock_stdout.buffer.getbuffer(), expected)

    @patch('sys.stdout', new_callable=io.BytesIO)
    def test_write_to_stdout(self, mock_stdout):
        # Here stdout itself is the byte stream (a BytesIO has no ``buffer``).
        StreamHandler().write('mock stream', 'foo')
        expected = 'mock stream'.encode('utf-8')
        self.assertEqual(mock_stdout.getbuffer(), expected)

    def test_pretty_path_returns_input_path(self):
        '''Should return the input str, which often is the default registered schema "-"'''
        handler = StreamHandler()
        self.assertEqual(handler.pretty_path('foo'), 'foo')


class TestNotebookNodeHandler(unittest.TestCase):
    '''Unit tests for NotebookNodeHandler.'''

    def test_read_notebook_node(self):
        fixture_path = get_notebook_path('test_notebooknode_io.ipynb')
        input_nb = nbformat.read(fixture_path, as_version=4)
        # NOTE(review): whitespace inside this literal must match the serializer's
        # output exactly - confirm against the committed file before relying on it.
        expect = (
            '{\n "cells": [\n {\n "cell_type": "code",\n "execution_count": null,'
            '\n "metadata": {},\n "outputs": [],\n "source": ['
            '\n "print(\'Hello World\')"\n ]\n }\n ],\n "metadata": {'
            '\n "kernelspec": {\n "display_name": "Python 3",\n "language": "python",'
            '\n "name": "python3"\n }\n },\n "nbformat": 4,\n "nbformat_minor": 2\n}'
        )
        result = NotebookNodeHandler().read(input_nb)
        self.assertEqual(result, expect)

    def test_raises_on_listdir(self):
        with self.assertRaises(PapermillException):
            NotebookNodeHandler().listdir('foo')

    def test_raises_on_write(self):
        with self.assertRaises(PapermillException):
            NotebookNodeHandler().write('foo', 'bar')

    def test_pretty_path(self):
        # The label is fixed regardless of the argument.
        actual = NotebookNodeHandler().pretty_path('foo')
        self.assertEqual(actual, 'NotebookNode object')
5 changes: 5 additions & 0 deletions papermill/tests/test_parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,8 @@ def test_parameterized_path_with_none_parameters(self):
def test_path_of_none_returns_none(self):
    # A path of None yields None, with or without parameters supplied.
    with_params = parameterize_path(path=None, parameters={'foo': 'bar'})
    self.assertIsNone(with_params)
    without_params = parameterize_path(path=None, parameters=None)
    self.assertIsNone(without_params)

def test_path_of_notebook_node_returns_input(self):
    # The very same object must come back, not a copy.
    notebook_path = get_notebook_path("simple_execute.ipynb")
    test_nb = load_notebook_node(notebook_path)
    returned = parameterize_path(test_nb, parameters=None)
    self.assertIs(returned, test_nb)