/*
 * Decompiled with CFR 0.152.
 */
package smile.datasets;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.stream.IntStream;
import org.apache.commons.csv.CSVFormat;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.io.Paths;
import smile.io.Read;

/**
 * The USPS digit dataset, held as a training and a test {@link DataFrame} plus the
 * {@link Formula} used to split each frame into predictors and response.
 *
 * <p>Each data file is read as space-delimited text with a byte-typed {@code class}
 * column followed by 256 float columns named {@code V1} .. {@code V256}.
 * (NOTE(review): presumably 16x16 pixel intensities, given the count — confirm.)
 *
 * @param train the training data.
 * @param test the test data.
 * @param formula the model formula; the loading constructors use {@code class} as the response.
 */
public record USPS(DataFrame train, DataFrame test, Formula formula) {
    /** Number of feature columns following the class label. */
    private static final int NUM_FEATURES = 256;

    /** Column layout of the data files: {@code class} (byte) then {@code V1..V256} (float). */
    private static final StructType SCHEMA = createSchema();

    /**
     * Space-delimited CSV format of the data files. {@code CSVFormat} is immutable, so one
     * shared instance replaces rebuilding it on every load.
     */
    private static final CSVFormat FORMAT = CSVFormat.Builder.create().setDelimiter(' ').get();

    /**
     * Loads the dataset from the test-data files {@code usps/zip.train} and
     * {@code usps/zip.test}, resolved through {@link Paths#getTestData}.
     *
     * @throws IOException if reading either file fails.
     */
    public USPS() throws IOException {
        this(Paths.getTestData("usps/zip.train"), Paths.getTestData("usps/zip.test"));
    }

    /**
     * Loads the dataset from the given files, using {@code class} as the response variable.
     *
     * @param trainDataPath the path of the training data file.
     * @param testDataPath the path of the test data file.
     * @throws IOException if reading either file fails.
     */
    public USPS(Path trainDataPath, Path testDataPath) throws IOException {
        this(USPS.load(trainDataPath), USPS.load(testDataPath), Formula.lhs("class"));
    }

    /** Builds the schema: the {@code class} label column followed by the feature columns. */
    private static StructType createSchema() {
        var fields = new ArrayList<StructField>(NUM_FEATURES + 1);
        fields.add(new StructField("class", DataTypes.ByteType));
        // Feature columns are 1-based: V1 .. V256.
        for (int i = 1; i <= NUM_FEATURES; i++) {
            fields.add(new StructField("V" + i, DataTypes.FloatType));
        }
        return new StructType(fields);
    }

    /** Reads one data file with the shared space-delimited format and fixed schema. */
    private static DataFrame load(Path path) throws IOException {
        return Read.csv(path, FORMAT, SCHEMA);
    }

    /**
     * Returns the training predictors as a matrix.
     *
     * @return the training feature matrix.
     */
    public double[][] x() {
        return features(this.train);
    }

    /**
     * Returns the training class labels.
     *
     * @return the training labels.
     */
    public int[] y() {
        return labels(this.train);
    }

    /**
     * Returns the test predictors as a matrix.
     *
     * @return the test feature matrix.
     */
    public double[][] testx() {
        return features(this.test);
    }

    /**
     * Returns the test class labels.
     *
     * @return the test labels.
     */
    public int[] testy() {
        return labels(this.test);
    }

    /**
     * Extracts the predictor matrix of {@code data} via the formula, encoding any
     * categorical predictors with {@link CategoricalEncoder#DUMMY}.
     */
    private double[][] features(DataFrame data) {
        return this.formula.x(data).toArray(false, CategoricalEncoder.DUMMY, new String[0]);
    }

    /** Extracts the response column of {@code data} via the formula as an int array. */
    private int[] labels(DataFrame data) {
        return this.formula.y(data).toIntArray();
    }
}

