diff --git a/ruby/lib/lox/ast_printer.rb b/ruby/lib/lox/ast_printer.rb index 8a89523..2a55dab 100644 --- a/ruby/lib/lox/ast_printer.rb +++ b/ruby/lib/lox/ast_printer.rb @@ -50,7 +50,8 @@ module Lox end def visit_class(stmt) - parenthesize("class #{stmt.name.lexeme}", *stmt.methods) + exprs = [stmt.superclass, *stmt.methods].compact + parenthesize("class #{stmt.name.lexeme}", *exprs) end def visit_function(stmt) diff --git a/ruby/lib/lox/interpreter.rb b/ruby/lib/lox/interpreter.rb index 95c8737..50939ab 100644 --- a/ruby/lib/lox/interpreter.rb +++ b/ruby/lib/lox/interpreter.rb @@ -45,13 +45,21 @@ module Lox end def visit_class(stmt) + superclass = if stmt.superclass + superclass = evaluate(stmt.superclass) + raise RuntimeError.new(stmt.superclass.name, "Superclass must be a class.") unless superclass.is_a?(LoxClass) + superclass + else + nil + end + @env.define(stmt.name.lexeme, nil) methods = stmt.methods.to_h {|method| [method.name.lexeme, Function.new(method, @env, method.name.lexeme == "init")] } - klass = LoxClass.new(stmt.name.lexeme, methods) + klass = LoxClass.new(stmt.name.lexeme, superclass, methods) @env.assign(stmt.name, klass) nil end diff --git a/ruby/lib/lox/lox_class.rb b/ruby/lib/lox/lox_class.rb index 017ea86..274adc8 100644 --- a/ruby/lib/lox/lox_class.rb +++ b/ruby/lib/lox/lox_class.rb @@ -5,8 +5,8 @@ module Lox attr_reader :name - def initialize(name, methods) - @name, @methods = name, methods + def initialize(name, superclass, methods) + @name, @superclass, @methods = name, superclass, methods end def find_method(name) = @methods[name] diff --git a/ruby/lib/lox/parser.rb b/ruby/lib/lox/parser.rb index 10e42b0..a68b6d5 100644 --- a/ruby/lib/lox/parser.rb +++ b/ruby/lib/lox/parser.rb @@ -33,6 +33,14 @@ module Lox def class_decl name = consume!(:IDENTIFIER, "Expect class name.") + + superclass = if match?(:LESS) + consume!(:IDENTIFIER, "Expect superclass name.") + Expr::Variable.new(prev) + else + nil + end + consume!(:LEFT_BRACE, "Expect '{' before class body.") methods = [] @@ -42,7 +50,7 @@ module Lox consume!(:RIGHT_BRACE, "Expect '}' after class body.") - Stmt::Class.new(name, methods) + Stmt::Class.new(name, superclass, methods) end def var_declaration diff --git a/ruby/lib/lox/resolver.rb b/ruby/lib/lox/resolver.rb index 25e14c6..b1c3f07 100644 --- a/ruby/lib/lox/resolver.rb +++ b/ruby/lib/lox/resolver.rb @@ -28,6 +28,12 @@ module Lox def visit_class(stmt) with_current_class(:CLASS) do declare(stmt.name) + define(stmt.name) + + if stmt.superclass + raise ResolverError.new(stmt.superclass.name, "A class can't inherit from itself.") if stmt.name.lexeme == stmt.superclass.name.lexeme + resolve(stmt.superclass) + end with_scope do @scopes.last["this"] = true @@ -36,8 +42,6 @@ module Lox decl = method.name.lexeme == "init" ? :INIT : :METHOD resolve_function(method, decl) end - - define(stmt.name) end end diff --git a/ruby/lib/lox/stmt.rb b/ruby/lib/lox/stmt.rb index 23d31a6..41a69b3 100644 --- a/ruby/lib/lox/stmt.rb +++ b/ruby/lib/lox/stmt.rb @@ -13,7 +13,7 @@ module Lox end stmt :Block, :stmts - stmt :Class, :name, :methods + stmt :Class, :name, :superclass, :methods stmt :Expr, :expr stmt :Function, :name, :params, :body stmt :If, :cond, :then, :else diff --git a/ruby/test/lox/test_interpreter.rb b/ruby/test/lox/test_interpreter.rb index a3f5073..682afc9 100644 --- a/ruby/test/lox/test_interpreter.rb +++ b/ruby/test/lox/test_interpreter.rb @@ -451,6 +451,25 @@ class TestInterpreter < Lox::Test SRC end + def test_inheritance + assert_interpreted "", <<~SRC + class Doughnut {} + class BostonCream < Doughnut {} + SRC + + assert_raises Lox::ResolverError do + interpret("class Oops < Oops {}") + end + + assert_raises Lox::RuntimeError do + interpret(<<~SRC) + var NotAClass = "I am totally not a class"; + + class Subclass < NotAClass {} // ?! + SRC + end + end + private def assert_interpreted(expected, src)