Skip to content

Commit

Permalink
Adds input of NotebookNode object directly (#670)
Browse files Browse the repository at this point in the history
* ENH: Allow passing notebooknode to execute

* ENH: Clean up parameterize_path and add to docstring

* ENH: Make file extension check optional

* DOC: Clearer error message

* DOC: Add to execute docstring

* DOC: Add docstring for handler classes

* TST: Adds test for parameterize path of NotebookNode

* TST: Update test for execute with notebook node

* TST: Adds tests for StreamHandler

* TST: Adds notebook node handler tests

* TST: Adds test fixture for notebooknode handler

* LNT: Fix quotes

* TST: Adds test for getting notebooknode handler from papermillio

Co-authored-by: dcnadler <nadler@amyris.com>
Co-authored-by: Rohit Sanjay <sanjay.rohit2@gmail.com>
  • Loading branch information
3 people committed Aug 12, 2022
1 parent 9f02383 commit a3f530e
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 47 deletions.
4 changes: 2 additions & 2 deletions papermill/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def execute_notebook(
Parameters
----------
input_path : str or Path
Path to input notebook
input_path : str or Path or nbformat.NotebookNode
Path to input notebook or NotebookNode object of notebook
output_path : str or Path or None
Path to save executed notebook. If None, no file will be saved
parameters : dict, optional
Expand Down
112 changes: 73 additions & 39 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,50 +96,14 @@ def __init__(self):
self.reset()

def read(self, path, extensions=['.ipynb', '.json']):
if path == '-':
return sys.stdin.read()

if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions
):
warnings.warn(
"The specified input file ({}) does not end in one of {}".format(path, extensions)
)
# Handle https://github.com/nteract/papermill/issues/317
notebook_metadata = self.get_handler(path).read(path)
notebook_metadata = self.get_handler(path, extensions).read(path)
if isinstance(notebook_metadata, (bytes, bytearray)):
return notebook_metadata.decode('utf-8')
return notebook_metadata

def write(self, buf, path, extensions=['.ipynb', '.json']):
if path is None:
return
if path == '-':
try:
return sys.stdout.buffer.write(buf.encode('utf-8'))
except AttributeError:
# Originally required by https://github.com/nteract/papermill/issues/420
# Support Buffer.io objects
return sys.stdout.write(buf.encode('utf-8'))

return sys.stdout.buffer.write(buf.encode('utf-8'))

# Usually no return object here
if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext) for ext in extensions
):
warnings.warn(
"The specified output file ({}) does not end in one of {}".format(path, extensions)
)
return self.get_handler(path).write(buf, path)
return self.get_handler(path, extensions).write(buf, path)

def listdir(self, path):
return self.get_handler(path).listdir(path)
Expand All @@ -159,10 +123,44 @@ def register_entry_points(self):
for entrypoint in entrypoints.get_group_all("papermill.io"):
self.register(entrypoint.name, entrypoint.load())

def get_handler(self, path):
def get_handler(self, path, extensions=None):
'''Get I/O Handler based on a notebook path
Parameters
----------
path : str or nbformat.NotebookNode or None
extensions : list of str, optional
Required file extension options for the path (if path is a string), which
will log a warning if there is no match. Defaults to None, which does not
check for any extensions
Raises
------
PapermillException: If a valid I/O handler could not be found for the input path
Returns
-------
I/O Handler
'''
if path is None:
return NoIOHandler()

if isinstance(path, nbformat.NotebookNode):
return NotebookNodeHandler()

if extensions:
if not fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*.*'):
warnings.warn(
"the file is not specified with any extension : " + os.path.basename(path)
)
elif not any(
fnmatch.fnmatch(os.path.basename(path).split('?')[0], '*' + ext)
for ext in extensions
):
warnings.warn(
"The specified file ({}) does not end in one of {}".format(path, extensions)
)

local_handler = None
for scheme, handler in self._handlers:
if scheme == 'local':
Expand Down Expand Up @@ -411,6 +409,41 @@ def pretty_path(self, path):
return path


class StreamHandler(object):
'''Handler for Stdin/Stdout streams'''
def read(self, path):
return sys.stdin.read()

def listdir(self, path):
raise PapermillException('listdir is not supported by Stream Handler')

def write(self, buf, path):
try:
return sys.stdout.buffer.write(buf.encode('utf-8'))
except AttributeError:
# Originally required by https://github.com/nteract/papermill/issues/420
# Support Buffer.io objects
return sys.stdout.write(buf.encode('utf-8'))

def pretty_path(self, path):
return path


class NotebookNodeHandler(object):
'''Handler for input_path of nbformat.NotebookNode object'''
def read(self, path):
return nbformat.writes(path)

def listdir(self, path):
raise PapermillException('listdir is not supported by NotebookNode Handler')

def write(self, buf, path):
raise PapermillException('write is not supported by NotebookNode Handler')

def pretty_path(self, path):
return 'NotebookNode object'


class NoIOHandler(object):
'''Handler for output_path of None - intended to not write anything'''

Expand Down Expand Up @@ -448,6 +481,7 @@ class NoDatesSafeLoader(yaml.SafeLoader):
papermill_io.register("hdfs://", HDFSHandler())
papermill_io.register("http://github.com/", GithubHandler())
papermill_io.register("https://github.com/", GithubHandler())
papermill_io.register("-", StreamHandler())
papermill_io.register_entry_points()


Expand Down
9 changes: 5 additions & 4 deletions papermill/parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def parameterize_path(path, parameters):
Parameters
----------
path : string or None
Path with optional parameters, as a python format string
path : string or nbformat.NotebookNode or None
Path with optional parameters, as a python format string. If path is a NotebookNode
or None, the path is returned without modification
parameters : dict or None
Arbitrary keyword arguments to fill in the path
"""
if path is None:
return
if path is None or isinstance(path, nbformat.NotebookNode):
return path

if parameters is None:
parameters = {}
Expand Down
20 changes: 20 additions & 0 deletions papermill/tests/notebooks/test_notebooknode_io.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": ["print('Hello World')"]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
16 changes: 16 additions & 0 deletions papermill/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from pathlib import Path

import nbformat
from nbformat import validate

from .. import engines, translators
Expand Down Expand Up @@ -436,3 +437,18 @@ def test_custom_kernel_name_and_language(self, translate_parameters, execute_man
)
self.assertEqual(execute_managed_notebook.call_args[0], (ANY, "my_custom_kernel"))
self.assertEqual(translate_parameters.call_args[0], (ANY, 'my_custom_language', {"msg": "fake msg"}, ANY))


class TestNotebookNodeInput(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.TemporaryDirectory()
self.result_path = os.path.join(self.test_dir.name, 'output.ipynb')

def tearDown(self):
self.test_dir.cleanup()

def test_notebook_node_input(self):
input_nb = nbformat.read(get_notebook_path('simple_execute.ipynb'), as_version=4)
execute_notebook(input_nb, self.result_path, {'msg': 'Hello'})
test_nb = nbformat.read(self.result_path, as_version=4)
self.assertEqual(test_nb.metadata.papermill.parameters, {'msg': 'Hello'})
63 changes: 61 additions & 2 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
LocalHandler,
NoIOHandler,
ADLHandler,
NotebookNodeHandler,
StreamHandler,
PapermillIO,
read_yaml_file,
papermill_io,
local_file_io_cwd,
)
from ..exceptions import PapermillException
from . import get_notebook_path

FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')

Expand Down Expand Up @@ -94,6 +97,10 @@ def test_get_local_handler(self):
def test_get_no_io_handler(self):
self.assertIsInstance(self.papermill_io.get_handler(None), NoIOHandler)

def test_get_notebook_node_handler(self):
test_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4)
self.assertIsInstance(self.papermill_io.get_handler(test_nb), NotebookNodeHandler)

def test_entrypoint_register(self):

fake_entrypoint = Mock(load=Mock())
Expand Down Expand Up @@ -155,7 +162,7 @@ def test_read_yaml_with_invalid_file_extension(self):
def test_read_stdin(self):
file_content = u'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
with patch('sys.stdin', io.StringIO(file_content)):
self.assertEqual(self.papermill_io.read("-"), file_content)
self.assertEqual(self.old_papermill_io.read("-"), file_content)

def test_listdir(self):
self.assertEqual(self.papermill_io.listdir("fake/path"), ["fake", "contents"])
Expand All @@ -178,7 +185,7 @@ def test_write_stdout(self):
file_content = u'Τὴ γλῶσσα μοῦ ἔδωσαν ἑλληνικὴ'
out = io.BytesIO()
with patch('sys.stdout', out):
self.papermill_io.write(file_content, "-")
self.old_papermill_io.write(file_content, "-")
self.assertEqual(out.getvalue(), file_content.encode('utf-8'))

def test_pretty_path(self):
Expand Down Expand Up @@ -343,3 +350,55 @@ def test_write_failure(self):

with self.assertRaises(ConnectionError):
HttpHandler.write(buf, path)


class TestStreamHandler(unittest.TestCase):
@patch('sys.stdin', io.StringIO('mock stream'))
def test_read_from_stdin(self):
result = StreamHandler().read('foo')
self.assertEqual(result, 'mock stream')

def test_raises_on_listdir(self):
with self.assertRaises(PapermillException):
StreamHandler().listdir(None)

@patch('sys.stdout')
def test_write_to_stdout_buffer(self, mock_stdout):
mock_stdout.buffer = io.BytesIO()
StreamHandler().write('mock stream', 'foo')
self.assertEqual(mock_stdout.buffer.getbuffer(), 'mock stream'.encode('utf-8'))

@patch('sys.stdout', new_callable=io.BytesIO)
def test_write_to_stdout(self, mock_stdout):
StreamHandler().write('mock stream', 'foo')
self.assertEqual(mock_stdout.getbuffer(), 'mock stream'.encode('utf-8'))

def test_pretty_path_returns_input_path(self):
'''Should return the input str, which often is the default registered schema "-"'''
self.assertEqual(StreamHandler().pretty_path('foo'), 'foo')


class TestNotebookNodeHandler(unittest.TestCase):
def test_read_notebook_node(self):
input_nb = nbformat.read(get_notebook_path('test_notebooknode_io.ipynb'), as_version=4)
result = NotebookNodeHandler().read(input_nb)
expect = (
'{\n "cells": [\n {\n "cell_type": "code",\n "execution_count": null,'
'\n "metadata": {},\n "outputs": [],\n "source": ['
'\n "print(\'Hello World\')"\n ]\n }\n ],\n "metadata": {'
'\n "kernelspec": {\n "display_name": "Python 3",\n "language": "python",'
'\n "name": "python3"\n }\n },\n "nbformat": 4,\n "nbformat_minor": 2\n}'
)
self.assertEqual(result, expect)

def test_raises_on_listdir(self):
with self.assertRaises(PapermillException):
NotebookNodeHandler().listdir('foo')

def test_raises_on_write(self):
with self.assertRaises(PapermillException):
NotebookNodeHandler().write('foo', 'bar')

def test_pretty_path(self):
expect = 'NotebookNode object'
self.assertEqual(NotebookNodeHandler().pretty_path('foo'), expect)
5 changes: 5 additions & 0 deletions papermill/tests/test_parameterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,8 @@ def test_parameterized_path_with_none_parameters(self):
def test_path_of_none_returns_none(self):
self.assertIsNone(parameterize_path(path=None, parameters={'foo': 'bar'}))
self.assertIsNone(parameterize_path(path=None, parameters=None))

def test_path_of_notebook_node_returns_input(self):
test_nb = load_notebook_node(get_notebook_path("simple_execute.ipynb"))
result_nb = parameterize_path(test_nb, parameters=None)
self.assertIs(result_nb, test_nb)

0 comments on commit a3f530e

Please sign in to comment.