Class: NanoGPT::TrainConfig

Inherits:
Object
  • Object
show all
Defined in:
lib/nano_gpt/train_config.rb

Overview

Configuration system for training and sampling. Supports JSON config files with command-line overrides.

Priority (highest to lowest):

1. Command-line arguments (--key=value)
2. JSON config file (--config=path.json)
3. Default values

Usage:

config = TrainConfig.load(ARGV)
config[:learning_rate]  # => 0.001

Constant Summary collapse

DEFAULTS =

Defaults match bin/train exactly

{
  # Baseline value for every known config key.
  #
  # The Ruby type of each default matters: TrainConfig.parse_value uses the
  # type of the current value (Integer, Float, true/false, or anything else)
  # to decide how a `--key=value` string from the command line is coerced.
  # A config created by TrainConfig.load starts with exactly these keys, and
  # keys not present are reported as unknown and skipped by both the JSON
  # loader and the command-line override pass.

  # I/O
  out_dir: "out-shakespeare-char",
  eval_interval: 250,
  log_interval: 10,
  eval_iters: 200,
  eval_only: false,
  always_save_checkpoint: false,
  init_from: "scratch",  # 'scratch' or 'resume'

  # Data
  dataset: "shakespeare_char",
  batch_size: 64,
  block_size: 256,
  gradient_accumulation_steps: 1,

  # Model
  n_layer: 6,
  n_head: 6,
  n_embd: 384,
  dropout: 0.2,
  bias: false,

  # Optimizer
  learning_rate: 1e-3,
  weight_decay: 1e-1,
  beta1: 0.9,
  beta2: 0.99,
  grad_clip: 1.0,

  # LR scheduler
  decay_lr: true,
  warmup_iters: 100,
  lr_decay_iters: 5000,
  min_lr: 1e-4,

  # Training
  max_iters: 5000,

  # System
  # NOTE(review): "auto" is presumably resolved to a concrete device by the
  # consumer of this config — confirm where that happens.
  device: "auto"
}.freeze

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(values = {}) ⇒ TrainConfig

Returns a new instance of TrainConfig.



65
66
67
# File 'lib/nano_gpt/train_config.rb', line 65

# Builds a config from DEFAULTS overlaid with any caller-supplied values.
#
# Keys are normalized to symbols before merging so that string-keyed input
# (for example a hash produced by JSON.parse) overrides the matching default
# instead of being stored under a separate String key, which #[] could never
# reach because it always looks keys up as symbols.
#
# @param values [Hash] overrides applied on top of DEFAULTS; keys may be
#   Symbols or Strings
def initialize(values = {})
  @values = DEFAULTS.merge(values.transform_keys(&:to_sym))
end

Instance Attribute Details

#values ⇒ Object (readonly)

Returns the value of attribute values.



63
64
65
# File 'lib/nano_gpt/train_config.rb', line 63

# Returns the Hash that backs this config.
#
# The Hash itself is returned, not a copy, so any change made through it is
# visible to #[] and the other methods of this object. Use #to_h for a copy.
def values
  @values
end

Class Method Details

.load(args) ⇒ Object

Load config from command-line args. Supports:

--config=path/to/config.json  (load JSON file)
--key=value                   (override specific values)


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# File 'lib/nano_gpt/train_config.rb', line 85

# Builds a config from an argv-style list of strings.
#
# Recognized arguments:
#   --config=path/to/config.json  (load a JSON file; only the first such
#                                  argument is used, later ones are skipped)
#   --key=value                   (override one known config key)
#
# Arguments that do not start with "--" or contain no "=" are ignored.
# Unknown keys produce a warning on stdout and are skipped.
#
# @param args [Array<String>] command-line arguments, e.g. ARGV
# @return [TrainConfig] defaults, then the JSON file, then overrides
def self.load(args)
  config = new

  # The JSON file is applied before the overrides so explicit flags win.
  config_arg = args.find { |arg| arg.start_with?("--config=") }
  config.load_json(config_arg.delete_prefix("--config=")) if config_arg

  args.each do |arg|
    is_override = arg.start_with?("--") && arg.include?("=") && !arg.start_with?("--config=")
    next unless is_override

    # Split on the first "=" only, so values may themselves contain "=".
    name, _separator, raw = arg[2..].partition("=")
    key = name.to_sym

    if config.values.key?(key)
      # The current value's type decides how the raw string is coerced.
      config[key] = parse_value(raw, config[key])
      puts "Override: #{key} = #{config[key]}"
    else
      puts "Warning: Unknown config key: #{key}"
    end
  end

  config
end

.parse_value(val, existing) ⇒ Object



150
151
152
153
154
155
156
157
# File 'lib/nano_gpt/train_config.rb', line 150

# Converts a raw command-line string into the type of the value it replaces.
#
# Parsing is strict: malformed input raises instead of being silently turned
# into 0, 0.0 or false, so a typo such as `--max_iters=5OOO` cannot start a
# run with a nonsensical setting.
#
# @param val [String] raw text taken from a `--key=value` argument
# @param existing [Object] current value for the key; its type selects the
#   conversion (Integer, Float, true/false, otherwise the string is kept)
# @return [Integer, Float, Boolean, String] the converted value
# @raise [ArgumentError] if val is not a valid integer, float or boolean
#   for the corresponding type
def self.parse_value(val, existing)
  case existing
  when Integer
    # Explicit base 10 keeps "010" meaning ten (no octal/hex prefixes).
    Integer(val, 10)
  when Float
    Float(val)
  when true, false
    case val.strip.downcase
    when "true", "1", "yes", "on" then true
    when "false", "0", "no", "off" then false
    else raise ArgumentError, "invalid boolean value: #{val.inspect}"
    end
  else
    val
  end
end

Instance Method Details

#[](key) ⇒ Object



69
70
71
# File 'lib/nano_gpt/train_config.rb', line 69

# Looks up a config value by name.
#
# @param key [Symbol, String] config key; converted with to_sym before lookup
# @return [Object, nil] the stored value, or nil when the key is absent
def [](key)
  name = key.to_sym
  @values[name]
end

#[]=(key, value) ⇒ Object



73
74
75
# File 'lib/nano_gpt/train_config.rb', line 73

# Stores a config value under the given name.
#
# @param key [Symbol, String] config key; converted with to_sym before storing
# @param value [Object] value to store
# @return [Object] the value that was stored
def []=(key, value)
  @values.store(key.to_sym, value)
end

#load_json(path) ⇒ Object

Load values from JSON file



122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# File 'lib/nano_gpt/train_config.rb', line 122

# Overlays values from a JSON file onto this config.
#
# Only keys already present in the config are accepted; unknown keys are
# reported on stdout and skipped.
#
# @param path [String] path to a JSON file whose top level is an object
# @return [TrainConfig] self, so calls can be chained
# @raise [RuntimeError] if the file does not exist or its top-level value is
#   not a JSON object
# @raise [JSON::ParserError] if the file is not valid JSON
def load_json(path)
  raise "Config file not found: #{path}" unless File.exist?(path)

  json = JSON.parse(File.read(path))
  raise "Config file must contain a JSON object: #{path}" unless json.is_a?(Hash)

  puts "Loaded config from #{path}"

  json.each do |name, val|
    key = name.to_sym
    unless @values.key?(key)
      puts "Warning: Unknown config key in JSON: #{key}"
      next
    end

    # A hand-written config often spells a float as a bare integer
    # ("grad_clip": 1). Keep the stored value a Float in that case; otherwise
    # a later `--grad_clip=0.5` override would be coerced with the Integer
    # rule and truncated.
    val = val.to_f if val.is_a?(Integer) && @values[key].is_a?(Float)
    @values[key] = val
  end

  self
end

#save_json(path) ⇒ Object

Save current config to JSON file



143
144
145
146
# File 'lib/nano_gpt/train_config.rb', line 143

# Writes the current config values to a file as pretty-printed JSON and
# reports the destination on stdout.
#
# @param path [String] file to create or overwrite
# @return [nil]
def save_json(path)
  serialized = JSON.pretty_generate(@values)
  File.write(path, serialized)
  puts "Saved config to #{path}"
end

#to_h ⇒ Object



77
78
79
# File 'lib/nano_gpt/train_config.rb', line 77

# Returns the config values as a new Hash.
#
# The copy is shallow: adding, removing or reassigning keys on the result
# does not affect this config, but the value objects themselves (for example
# Strings) are the same objects held by the config.
def to_h
  @values.dup
end