/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.linear;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.Option;
import java.util.logging.Logger;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.objectives.Hinge;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.optimisers.GradientOptimiserOptions;

/**
 * Command line options for building a {@link LinearSGDTrainer}.
 * <p>
 * Each {@code @Option} field carries its own default; {@link #getTrainer()}
 * reads the current field values and passes them to the
 * {@link LinearSGDTrainer} constructor.
 */
public class LinearSGDOptions
implements ClassificationOptions<LinearSGDTrainer> {
    private static final Logger logger = Logger.getLogger(LinearSGDOptions.class.getName());

    /**
     * Options used to build the gradient optimiser. This field has no
     * initializer in this class, so it must be assigned before
     * {@link #getTrainer()} is called.
     */
    public GradientOptimiserOptions sgoOptions;

    /** Number of SGD epochs. */
    @Option(longName="sgd-epochs", usage="Number of SGD epochs. Defaults to 5.")
    public int sgdEpochs = 5;

    /** Which loss function {@link #getLoss()} instantiates. */
    @Option(longName="sgd-objective", usage="Loss function. Defaults to LOG.")
    public LossEnum sgdObjective = LossEnum.LOG;

    /** Number of examples between objective log messages. */
    @Option(longName="sgd-logging-interval", usage="Log the objective after <int> examples. Defaults to 100.")
    public int sgdLoggingInterval = 100;

    /** Minibatch size. */
    @Option(longName="sgd-minibatch-size", usage="Minibatch size. Defaults to 1.")
    public int sgdMinibatchSize = 1;

    /** Random seed handed to the trainer. */
    @Option(longName="sgd-seed", usage="Sets the random seed for the LinearSGDTrainer.")
    public long sgdSeed = 12345L;

    /**
     * Creates the loss function selected by {@link #sgdObjective}.
     *
     * @return a new {@link Hinge} for {@link LossEnum#HINGE}, or a new
     *         {@link LogMulticlass} for {@link LossEnum#LOG}.
     * @throws ArgumentException if {@link #sgdObjective} is {@code null} or is
     *         a value this method does not recognise.
     */
    public LabelObjective getLoss() {
        LossEnum objective = sgdObjective;
        // Guard the switch: switching on a null enum would raise a bare
        // NullPointerException rather than the ArgumentException used for
        // every other unusable objective value.
        if (objective != null) {
            switch (objective) {
                case HINGE:
                    return new Hinge();
                case LOG:
                    return new LogMulticlass();
                default:
                    break;
            }
        }
        throw new ArgumentException("sgd-objective", "Unknown loss function " + objective);
    }

    /**
     * Builds a {@link LinearSGDTrainer} from the current option values.
     * <p>
     * Logs the configured logging interval at INFO level, then passes the
     * loss, the optimiser from {@link #sgoOptions}, the epoch count, the
     * logging interval, the minibatch size and the seed to the trainer
     * constructor, in that order.
     *
     * @return the newly constructed trainer.
     * @throws ArgumentException if the loss function cannot be resolved.
     * @throws NullPointerException if {@link #sgoOptions} has not been set.
     */
    public LinearSGDTrainer getTrainer() {
        logger.info(String.format("Set logging interval to %d", sgdLoggingInterval));
        LabelObjective loss = getLoss();
        // Same exception type as the implicit dereference failure, but with a
        // message that names the missing configuration.
        if (sgoOptions == null) {
            throw new NullPointerException("sgoOptions must be set before calling getTrainer()");
        }
        return new LinearSGDTrainer(loss, sgoOptions.getOptimiser(), sgdEpochs, sgdLoggingInterval, sgdMinibatchSize, sgdSeed);
    }

    /**
     * The loss functions selectable through {@code --sgd-objective}.
     */
    public static enum LossEnum {
        /** Selects the {@link Hinge} objective. */
        HINGE,
        /** Selects the {@link LogMulticlass} objective. */
        LOG;

    }
}

