#!/usr/bin/env ruby
# frozen_string_literal: true

$:.unshift(File.expand_path("../lib", __dir__))
require "prism"

module Prism
  class CLI
    # Entry point for the command-line interface. The first element of argv
    # names the command and is removed from argv; whatever remains is handed to
    # the commands that accept arguments. A missing or unrecognized command
    # prints the usage text to stdout.
    def run(argv)
      command = argv.shift

      case command
      when "benchmark"
        benchmark(argv)
      when "bundle"
        bundle(argv)
      when "console"
        console
      when "dot"
        dot(argv)
      when "encoding"
        encoding(argv)
      when "error"
        error(argv)
      when "lex"
        lex(argv)
      when "locals"
        locals(argv)
      when "parse"
        parse(argv)
      when "parser"
        parser(argv)
      when "ripper"
        ripper(argv)
      when "rubyparser"
        rubyparser(argv)
      when "repl"
        repl
      else
        puts <<~TXT
          Usage:
            bin/prism benchmark [source_file]
            bin/prism bundle [...]
            bin/prism console
            bin/prism dot [source]
            bin/prism encoding [encoding]
            bin/prism error [name] [source]
            bin/prism lex [source]
            bin/prism locals [source]
            bin/prism parse [source]
            bin/prism parser [source]
            bin/prism ripper [source]
            bin/prism rubyparser [source]
            bin/prism repl
        TXT
      end
    end

    private

    ############################################################################
    # Commands
    ############################################################################

    # bin/prism benchmark [source_file]
    #
    # Measure iterations per second for prism, for each of its translation
    # layers, and for the libraries those layers mirror, all parsing the same
    # file. The file defaults to lib/prism/node.rb relative to this script.
    def benchmark(argv)
      %w[benchmark/ips parser/current ripper ruby_parser].each { |library| require library }

      filepath = argv.empty? ? File.expand_path("../lib/prism/node.rb", __dir__) : argv[0]
      source = File.read(filepath)

      # Report label => the work to time, in reporting order.
      contenders = {
        "Prism" =>
          proc { Prism.parse(source, filepath: filepath) },
        "Ripper::SexpBuilder" =>
          proc { Ripper.sexp_raw(source, filepath) },
        "Prism::Translation::Ripper::SexpBuilder" =>
          proc { Prism::Translation::Ripper.sexp_raw(source, filepath) },
        "Parser::CurrentRuby" =>
          proc { Parser::CurrentRuby.parse(source, filepath) },
        "Prism::Translation::Parser" =>
          proc { Prism::Translation::Parser.parse(source, filepath) },
        "RubyParser" =>
          proc { RubyParser.new.parse(source, filepath) },
        "Prism::Translation::RubyParser" =>
          proc { Prism::Translation::RubyParser.new.parse(source, filepath) }
      }

      Benchmark.ips do |x|
        contenders.each { |label, work| x.report(label, &work) }
        x.compare!
      end
    end

    # bin/prism bundle [...]
    #
    # Run `bundle` with the given arguments once per gemfile listed below. Each
    # run goes through chruby-exec for the paired Ruby version, with
    # BUNDLE_GEMFILE pointing at gemfiles/<gemfile>/Gemfile and the working
    # directory set to the parent of this script's directory.
    #
    # `system` is called with `exception: true`; a RuntimeError raised by a run
    # is reported on stderr in red and the remaining combinations still run.
    # Other exceptions are not rescued here.
    def bundle(argv)
      root = File.expand_path("../", __dir__)

      [
        ["3.1", ["3.1"]],
        ["3.2", ["3.2"]],
        ["3.3", ["3.3"]],
        ["3.4", ["3.4", "typecheck"]],
        ["3.5", ["3.5"]],
        ["jruby", ["jruby"]],
        ["truffleruby", ["truffleruby"]]
      ].each do |ruby_version, gemfiles|
        gemfiles.each do |gemfile|
          system(
            { "BUNDLE_GEMFILE" => File.expand_path("../gemfiles/#{gemfile}/Gemfile", __dir__) },
            "chruby-exec", ruby_version, "--", "bundle", *argv,
            exception: true,
            chdir: root
          )
        rescue RuntimeError => error
          # "\e[31m" selects a red foreground. The sequence must not end in
          # ";": an empty SGR parameter means 0 (reset), which would switch the
          # color straight back off before the message is printed.
          warn("\e[31m#{ruby_version} (#{gemfile}): #{error.message}\e[m")
        end
      end
    end

    # bin/prism console
    #
    # Start an IRB session. IRB is required here rather than at the top of the
    # file, so it is only loaded when this command is used.
    def console
      require "irb"
      IRB.start(__FILE__)
    end

    # bin/prism dot [source]
    #
    # Parse the source, walk from the root node to a nested node, and render
    # that node's `to_dot` output to out.svg by piping it through `dot -Tsvg`.
    #
    # The node is selected by the argument that follows the source: when it
    # consists solely of word characters and dots it is split on "." and each
    # piece is sent to the current node in turn. Otherwise the path
    # statements.body.first is used.
    def dot(argv)
      result = parse_source(argv)

      selector = argv.first
      fields = selector&.match?(/\A[\w\.]+\z/) ? selector.split(".") : %w[statements body first]

      node = fields.reduce(result.value) { |current, field| current.public_send(field) }

      # Feed the graph description to dot, close our write end so that dot
      # sees end of input, then collect the SVG it produced.
      svg =
        IO.popen("dot -Tsvg", "w+") do |pipe|
          pipe.write(node.to_dot)
          pipe.close_write
          pipe.read
        end

      File.write("out.svg", svg)
    end

    # bin/prism encoding [encoding]
    #
    # Print the lookup table for the named encoding and, for UTF-8 and
    # UTF8-MAC, the unicode codepoint lists as well. US-ASCII is swapped for
    # ASCII-8BIT before anything is generated. Encodings that are not ASCII
    # compatible are rejected with a message on stderr and exit status 1.
    def encoding(argv)
      requested = Encoding.find(argv.shift)
      target = requested == Encoding::US_ASCII ? Encoding::ASCII_8BIT : requested

      unless target.ascii_compatible?
        warn("Encoding `#{target.name}' is not ASCII compatible")
        exit(1)
      end

      lookup_table(target)
      unicode_lists(target) if [Encoding::UTF_8, Encoding::UTF8_MAC].include?(target)
    end

    # bin/prism error [name] [source]
    #
    # Write the formatted errors for a piece of source to a fixture under
    # test/prism/errors.
    #
    # With only a name, the existing fixture <name>.txt is read back, the lines
    # whose first non-whitespace character is "^" are dropped, trailing
    # newlines are removed, and the file is regenerated from what is left.
    #
    # With a source as well, the source is parsed and written to <name>.txt,
    # or, when that file already exists, to the first of <name>_2.txt,
    # <name>_3.txt, ... that does not exist yet.
    #
    # Raises if the fixture to regenerate is missing or if the source parses
    # without errors.
    def error(argv)
      name = argv.shift
      filepath = File.expand_path("../test/prism/errors/#{name}.txt", __dir__)

      source =
        if argv.empty?
          raise "Expected #{filepath} to exist" unless File.file?(filepath)

          annotated = File.read(filepath, binmode: true, external_encoding: Encoding::UTF_8)
          annotated.lines.grep_v(/^\s*\^/).join.gsub(/\n*\z/, "")
        else
          if File.file?(filepath)
            directory = File.dirname(filepath)
            stem = File.basename(filepath, ".txt")

            # Numbered variants start at 2; take the first one that is free.
            suffix = 2
            suffix += 1 while File.file?("#{directory}/#{stem}_#{suffix}.txt")

            filepath = "#{directory}/#{stem}_#{suffix}.txt"
          end

          read_source(argv).first
        end

      result = Prism.parse(source)
      raise "Expected #{source.inspect} to have errors" if result.success?

      File.write(filepath, result.errors_format)
    end

    # bin/prism lex [source]
    #
    # Print the tokens from Prism.lex_ripper, Prism.lex_compat and Prism.lex
    # side by side, one row per token index. Rows where the first two columns
    # differ are printed in bold red; matching rows are printed (in grey) only
    # when the VERBOSE environment variable is set. Any errors reported by
    # lex_compat are listed before the table.
    def lex(argv)
      source, filepath = read_source(argv)

      # When Ripper rejects the source, carry on as though it had produced no
      # tokens. prism does not raise on a syntax error, so its columns are
      # still worth printing.
      ripper_tokens =
        begin
          Prism.lex_ripper(source)
        rescue ArgumentError, SyntaxError
          []
        end

      compat_result = Prism.lex_compat(source, filepath: filepath)
      prism_result = Prism.lex(source, filepath: filepath)

      if compat_result.failure?
        puts "Errors lexing:"

        compat_result.errors.each do |error|
          print "- [#{error.location.start_line},#{error.location.start_column}-"
          print "#{error.location.end_line},#{error.location.end_column}] "
          puts "\e[1;31m#{error.message}\e[0m"
        end

        puts "\n"
      end

      row_format = "%-64s %-64s %-64s"
      rule = "-" * 64
      puts row_format % ["Ripper lex", "Prism compat lex", "Prism lex"]
      puts row_format % [rule, rule, rule]

      compat_tokens = compat_result.value
      prism_tokens = prism_result.value
      row_count = [ripper_tokens, compat_tokens, prism_tokens].map(&:length).max

      row_count.times do |index|
        ripper_column = ripper_tokens[index]
        compat_column = compat_tokens[index]

        # The third column is a summary built from the first element of the
        # prism entry, or nil once the prism tokens run out.
        prism_entry = prism_tokens[index]
        prism_column =
          unless prism_entry.nil?
            token = prism_entry[0]
            location = token.location
            [[location.start_line, location.start_column], token.type, location.slice]
          end

        cells = [ripper_column, compat_column, prism_column].map(&:inspect)

        if ripper_column != compat_column
          puts "\033[1;31m#{row_format}\033[0m" % cells
        elsif ENV["VERBOSE"]
          puts "\033[38;5;102m#{row_format}\033[0m" % cells
        end
      end
    end

    # bin/prism locals [source]
    #
    # Print what Debug.cruby_locals and Debug.prism_locals each return for the
    # source, under a "CRuby:" and a "Prism:" heading respectively.
    def locals(argv)
      source = read_source(argv).first

      puts "CRuby:"
      cruby = Debug.cruby_locals(source)
      p cruby

      puts "Prism:"
      prism = Debug.prism_locals(source)
      p prism
    end

    # bin/prism parse [source]
    #
    # Parse the source and print the result. When the parse produced no
    # comments, magic comments, warnings, errors or DATA location, only the
    # inspected AST is printed. Otherwise each non-empty part is pretty-printed
    # under its own heading, followed by the AST, with a blank line between
    # sections.
    def parse(argv)
      result = parse_source(argv)

      sections = {
        "Comments" => result.comments,
        "Magic comments" => result.magic_comments,
        "Warnings" => result.warnings,
        "Errors" => result.errors
      }.select { |_title, entries| entries.any? }
      sections["DATA"] = result.data_loc if result.data_loc

      if sections.empty?
        puts result.value.inspect
      else
        sections["AST"] = result.value

        leading = true
        sections.each do |title, payload|
          puts unless leading
          leading = false

          puts "#{title}:"
          pp payload
        end
      end
    end

    # bin/prism parser [source]
    #
    # Run the same source buffer through Parser::Ruby34 and through
    # Prism::Translation::Parser34, printing what each one produces under a
    # "Parser:" and a "Prism:" heading.
    def parser(argv)
      require "parser/ruby34"
      source, filepath = read_source(argv)

      buffer = Parser::Source::Buffer.new(filepath, 1).tap { |fresh| fresh.source = source }

      puts "Parser:"
      reference = Parser::Ruby34.new
      parser_parse(reference, buffer)

      puts "Prism:"
      translated = Prism::Translation::Parser34.new
      parser_parse(translated, buffer)
    end

    # bin/prism ripper [source]
    #
    # Pretty-print the raw s-expressions that Ripper and
    # Prism::Translation::Ripper build for the source, then report whether the
    # two are equal.
    def ripper(argv)
      require "ripper"
      source = read_source(argv).first

      expected = Ripper.sexp_raw(source)
      actual = Prism::Translation::Ripper.sexp_raw(source)

      puts "Ripper:"
      pp expected

      puts "Prism:"
      pp actual

      qualifier = expected == actual ? "" : "not "
      puts "Output is #{qualifier}identical"
    end

    # bin/prism rubyparser [source]
    #
    # Parse the source with RubyParser and with Prism::Translation::RubyParser
    # and pretty-print both results, each under its own heading. Both parses
    # run before anything is printed.
    def rubyparser(argv)
      require "ruby_parser"
      source, filepath = read_source(argv)

      reference = RubyParser.new.parse(source, filepath)
      translated = Prism::Translation::RubyParser.parse(source, filepath)

      puts "RubyParser:"
      pp reference

      puts "Prism:"
      pp translated
    end

    # bin/prism repl
    #
    # Read lines from standard input until end of input. Each line is parsed
    # with Prism.parse; its errors are printed one per line, followed by the
    # statements in the result: the single statement itself when there is
    # exactly one, otherwise the whole list.
    def repl
      loop do
        print "Prism> "

        # Read from $stdin explicitly. A bare `gets` goes through ARGF, which
        # treats anything left in ARGV as a list of files to read from, so a
        # stray extra argument would be opened as a file instead of prompting.
        input = $stdin.gets
        break if input.nil?

        result = Prism.parse(input.chomp)
        statements = result.value.statements.body

        result.errors.each { |error| p(error) }

        p(statements.one? ? statements.first : statements)
      end
    end

    ############################################################################
    # Helpers
    ############################################################################

    # Tokenize a Parser::Source::Buffer with the given parser and pretty-print
    # the outcome: the diagnostics the parser emitted (only when there are
    # any), then the AST, then the tokens. The second value returned by
    # tokenize is ignored.
    def parser_parse(parser, buffer)
      collected = []
      parser.diagnostics.consumer = lambda { |diagnostic| collected << diagnostic }

      ast, _ignored, tokens = parser.tokenize(buffer, true)

      pp collected if collected.any?
      pp ast
      pp tokens
    end

    # Build the 16x16 grid of flag bytes behind an encoding's lookup table.
    # Entry [row][column] describes codepoint row * 16 + column: bit 0 is set
    # when the character is alphabetic, bit 1 when it is alphanumeric and bit 2
    # when it is uppercase. Codepoints that cannot be represented in the
    # encoding (Integer#chr raises RangeError) get 0.
    def lookup_table_values(encoding)
      flags = {
        /[[:alpha:]]/ => 1 << 0,
        /[[:alnum:]]/ => 1 << 1,
        /[[:upper:]]/ => 1 << 2
      }

      (0...256).each_slice(16).map do |row|
        row.map do |codepoint|
          character = codepoint.chr(encoding)

          # The bits are disjoint, so adding them is the same as OR-ing them.
          flags.sum { |pattern, bit| character.match?(pattern) ? bit : 0 }
        rescue RangeError
          0
        end
      end
    end

    # Print the C source of the 256-entry lookup table for an encoding, as
    # sixteen rows of sixteen values, each row annotated with its high nibble.
    #
    # The table is named pm_encoding_ascii_table when its values match the
    # US-ASCII table. Otherwise the name is derived from the encoding's own
    # name, lowercased, with every run of characters that cannot appear in a C
    # identifier replaced by "_" (so ISO-8859-1 becomes
    # pm_encoding_iso_8859_1_table rather than an identifier with hyphens).
    def lookup_table(encoding)
      encoding_values = lookup_table_values(encoding)

      table_name =
        if encoding_values == lookup_table_values(Encoding::US_ASCII)
          "ascii"
        else
          encoding.name.downcase.gsub(/[^a-z0-9]+/, "_")
        end

      puts "static const uint8_t pm_encoding_#{table_name}_table[256] = {"
      puts "//  #{(0...16).map { |value| value.to_s(16).upcase }.join("  ")}"
      encoding_values.each_with_index do |row, row_index|
        puts "    #{row.join(", ")}, // #{row_index.to_s(16).upcase}x"
      end
      puts "};"
    end

    # Print C arrays listing the codepoints that are alphabetic, alphanumeric
    # and uppercase. Only codepoints from 0x100 upwards are considered, and the
    # surrogate range 0xD800-0xDFFF is skipped. Each array stores consecutive
    # codepoints as a (first, last) pair, so its length is twice the number of
    # runs.
    #
    # Note that the encoding argument is overwritten on entry: characters are
    # always built as UTF-8, whatever was passed in.
    def unicode_lists(encoding)
      encoding = Encoding::UTF_8
      candidates = [*(0x100..0xD7FF), *(0xE000..0x10FFFF)]

      properties = { alpha: /[[:alpha:]]/, alnum: /[[:alnum:]]/, isupper: /[[:upper:]]/ }

      properties.map do |kind, regex|
        codepoints = candidates.select { |codepoint| codepoint.chr(encoding).match?(regex) }

        # Cut the sorted list wherever two neighbours are not consecutive, then
        # keep just the two ends of every run.
        bounds =
          codepoints
            .slice_when { |left, right| right - left != 1 }
            .flat_map { |run| [run.first, run.last].map { |edge| "0x#{edge.to_s(16).upcase}" } }

        macro = "UNICODE_#{kind.upcase}_CODEPOINTS_LENGTH"

        puts "\n#define #{macro} #{bounds.length}"
        puts "unicode_codepoint_t unicode_#{kind}_codepoints[#{macro}] = {"
        bounds.each_slice(2) { |pair| puts "    #{pair.join(", ")}," }
        puts "};"
      end
    end

    # Parse the source selected by the command-line arguments, consuming the
    # arguments that were used.
    #
    # Any leading -a, -l, -n, -p and -x switches are collected, without their
    # dash, into the command_line option given to prism. After them, "-e CODE"
    # parses CODE (and adds "e" to command_line), a path parses that file, and
    # no argument at all parses test.rb.
    def parse_source(argv)
      switches = +""

      while (candidate = argv.first) && candidate.match?(/^-[alnpx]$/)
        switches << argv.shift[1]
      end

      target = argv.first

      if target == "-e"
        switches << argv.shift[1]
        Prism.parse(argv.shift, command_line: switches)
      elsif target.nil?
        Prism.parse_file("test.rb", command_line: switches)
      else
        Prism.parse_file(argv.shift, command_line: switches)
      end
    end

    # Fetch the source selected by the command-line arguments, consuming the
    # arguments that were used. Returns a [source, filepath] pair: "-e CODE"
    # yields CODE with "-e" as the filepath, a path yields that file's
    # contents, and no argument at all yields the contents of test.rb.
    def read_source(argv)
      if argv.first == "-e"
        argv.shift
        return [argv.shift, "-e"]
      end

      filepath = argv.first.nil? ? "test.rb" : argv.shift
      [File.read(filepath), filepath]
    end
  end
end

# Run the command named by the arguments this script was invoked with. ARGV
# itself is passed in, so the commands consume it as they shift arguments off.
Prism::CLI.new.run(ARGV)
