# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import ast
import io
import operator
import sys
import textwrap
import token
import tokenize


class Visitor(ast.NodeVisitor):
    def __init__(self, lines):
        self._lines = lines
        self.line_numbers_with_nodes = set()
        self.line_numbers_with_statements = []

    def generic_visit(self, node):
        # Nodes at column 0 are top-level (global) nodes.
        if (
            hasattr(node, "col_offset")
            and hasattr(node, "lineno")
            and node.col_offset == 0
        ):
            self.line_numbers_with_nodes.add(node.lineno)
        if isinstance(node, ast.stmt):
            self.line_numbers_with_statements.append(node.lineno)

        ast.NodeVisitor.generic_visit(self, node)


def _tokenize(source):
    """Tokenize Python source code."""
    # Using an undocumented API, as the documented one in Python 2.7 does not
    # work as needed cross-version.
    if sys.version_info < (3,) and isinstance(source, str):
        source = source.decode()
    return tokenize.generate_tokens(io.StringIO(source).readline)


def _indent_size(line):
    """Return the number of leading whitespace characters in `line`."""
    for index, char in enumerate(line):
        if not char.isspace():
            return index


def _get_global_statement_blocks(source, lines):
    """Return a list of all global statement blocks.

    The list is made up of 3-item tuples containing the starting line
    number, the ending line number, and whether the statement is a
    single line.
    """
    tree = ast.parse(source)
    visitor = Visitor(lines)
    visitor.visit(tree)

    statement_ranges = []
    for index, line_number in enumerate(visitor.line_numbers_with_statements):
        remaining_line_numbers = visitor.line_numbers_with_statements[index + 1 :]
        end_line_number = (
            len(lines)
            if len(remaining_line_numbers) == 0
            else min(remaining_line_numbers) - 1
        )
        current_statement_is_oneline = line_number == end_line_number

        if len(statement_ranges) == 0:
            statement_ranges.append(
                (line_number, end_line_number, current_statement_is_oneline)
            )
            continue

        previous_statement = statement_ranges[-1]
        previous_statement_is_oneline = previous_statement[2]
        if previous_statement_is_oneline and current_statement_is_oneline:
            # Merge runs of consecutive single-line statements into one block.
            statement_ranges[-1] = previous_statement[0], end_line_number, True
        else:
            statement_ranges.append(
                (line_number, end_line_number, current_statement_is_oneline)
            )

    return statement_ranges


def normalize_lines(source):
    """Normalize blank lines for sending to the terminal.

    Blank lines within a statement block are removed to prevent the REPL
    from thinking the block is finished. Newlines are added to separate
    top-level statements so that the REPL does not think there is a
    syntax error.
    """
    # Make sure to dedent the code (#2837).
    lines = textwrap.dedent(source).splitlines(False)

    # If the source ends with two blank lines, preserve two trailing newlines.
    # Do not trim the spaces: if we have blank lines with spaces, it is
    # possible we have indented code.
    if (len(lines) > 1 and len("".join(lines[-2:])) == 0) or source.endswith(
        ("\n\n", "\r\n\r\n")
    ):
        trailing_newline = "\n" * 2
    # Find out if we have any trailing blank lines.
    elif len(lines[-1].strip()) == 0 or source.endswith(("\n", "\r\n")):
        trailing_newline = "\n"
    else:
        trailing_newline = ""

    # Step 1: Remove empty lines (tokenized as NL tokens on blank lines).
    tokens = _tokenize(source)
    newlines_indexes_to_remove = (
        spos[0]
        for (toknum, tokval, spos, epos, line) in tokens
        if len(line.strip()) == 0
        and token.tok_name[toknum] == "NL"
        and spos[0] == epos[0]
    )

    # Delete from the bottom up so earlier line numbers remain valid.
    for line_number in reversed(list(newlines_indexes_to_remove)):
        del lines[line_number - 1]

    # Step 2: Add blank lines between each global statement block.
    # Consecutive single-line blocks of code are treated as a single
    # statement, to ensure we do not add unnecessary blank lines.
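    # For example, given this (illustrative) input:
    #
    #     x = (1 +
    #
    #          2)
    #     y = 3
    #
    # Step 1 removes the blank line inside the parenthesized expression, and
    # Step 2 then inserts a blank line before "y = 3" so the REPL sees two
    # separate top-level statements.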
source = "\n".join(lines) tokens = _tokenize(source) dedent_indexes = ( spos[0] for (toknum, tokval, spos, epos, line) in tokens if toknum == token.DEDENT and _indent_size(line) == 0 ) global_statement_ranges = _get_global_statement_blocks(source, lines) start_positions = map(operator.itemgetter(0), reversed(global_statement_ranges)) for line_number in filter(lambda x: x > 1, start_positions): lines.insert(line_number - 1, "") sys.stdout.write("\n".join(lines) + trailing_newline) sys.stdout.flush() if __name__ == "__main__": contents = sys.argv[1] try: default_encoding = sys.getdefaultencoding() encoded_contents = contents.encode(default_encoding, "surrogateescape") contents = encoded_contents.decode(default_encoding, "replace") except (UnicodeError, LookupError): pass if isinstance(contents, bytes): contents = contents.decode("utf8") normalize_lines(contents)