
Make Optimizer parameters public to enable dynamic learning rate #87

Merged: 1 commit merged into ml-explore:main on May 20, 2024

Conversation

kemchenj (Contributor)

While using mlx-swift, I found that the optimizers do not expose the learning rate property, making it impossible to adjust the learning rate dynamically during training.

I believe this is a significant functionality gap, so this PR makes all relevant Optimizer properties public.
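
For example, with learningRate writable, a training loop can decay it between epochs. A minimal sketch (assuming the SGD optimizer from the MLXOptimizers module; the loss/gradient computation is elided in a comment):

import MLXOptimizers

let optimizer = SGD(learningRate: 0.1)

for epoch in 0 ..< 5 {
    // ... compute gradients and call optimizer.update(model:gradients:) ...
    // hypothetical manual schedule: shrink the rate by 10% each epoch
    optimizer.learningRate *= 0.9
}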

davidkoski (Collaborator) commented on the diff:

var dampening: Float = 0
var nesterov = false
/// The learning rate
public var learningRate: Float

These changes look good to me, but I wonder about changing non-MLXArray values during the run and the interaction with grad().

@awni will this work OK?

I realized there was a scheduler piece that happened on the Python side, and we don't have an equivalent here.

awni (Member)

It should be fine to change non-MLX arrays. The main (and I think only) time you need to be careful is when using mx.compile. In Python we detect when constants (non-arrays) change and recompile, but I don't think we do that in Swift yet.

In Python everything is public by default, so this kind of mimics that. I would say it's OK to do it here. We do have schedulers in Python, which are the typical way of updating a learning rate (and other variables). To the extent that the interface has to change for the schedulers, that may be something to watch out for (e.g. if we add schedulers, will we break the API that this exposes?).
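
Concretely, the compile caveat above could bite like this (a hedged sketch assuming MLX Swift's compile() free function; the stale-constant behavior is the caveat being described, not confirmed mlx-swift behavior):

import MLX

var learningRate: Float = 0.1

// The compiled closure captures learningRate as a plain Swift constant at
// trace time; Python would detect the change and recompile, Swift may not yet.
let sgdStep = compile { (inputs: [MLXArray]) -> [MLXArray] in
    let (param, grad) = (inputs[0], inputs[1])
    return [param - grad * learningRate]
}

var param = MLXArray([1.0, 2.0] as [Float])
let grad = MLXArray([0.5, 0.5] as [Float])

param = sgdStep([param, grad])[0]
learningRate = 0.01  // may be ignored: the old graph can still bake in 0.1
param = sgdStep([param, grad])[0]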

davidkoski (Collaborator)

> To the extent that the interface has to change for the schedulers, that may be something to watch out for (e.g. if we add schedulers, will we break the API that this exposes?).

Possibly, but maybe not. It looks like the optimizer base gets a step, so we could do something like this:

public var step: MLXArray

public var learningRate: Float {
    get {
        // With a scheduler attached, derive the rate from the current step.
        if let scheduledLearningRate {
            return scheduledLearningRate(step).item(Float.self)
        } else {
            return _learningRate
        }
    }
    set {
        if scheduledLearningRate != nil {
            fatalError("cannot set learningRate with a scheduler attached")
        }
        _learningRate = newValue
    }
}

private var _learningRate: Float = 0

/// Optional schedule mapping the optimizer step to a learning rate.
public var scheduledLearningRate: ((MLXArray) -> MLXArray)?

That would keep the same API while adding an optional scheduled parameter.
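
For instance, a caller could then attach an exponential decay (hypothetical usage of the sketch above; none of this is merged mlx-swift API):

// Assumes `optimizer` exposes the sketched scheduledLearningRate property:
// exponential decay driven by the optimizer's step counter.
optimizer.scheduledLearningRate = { step in
    MLXArray(Float(0.1)) * pow(MLXArray(Float(0.99)), step)
}
// Reading optimizer.learningRate now evaluates the schedule at the current step.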

davidkoski (Collaborator) left a comment

Thank you for these changes!

davidkoski merged commit 83efa17 into ml-explore:main on May 20, 2024 (1 check passed).