/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.ttl;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.CompositeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyedStateFactory;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.runtime.state.internal.InternalFoldingState;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.runtime.state.internal.InternalReducingState;
import org.apache.flink.runtime.state.internal.InternalValueState;
import org.apache.flink.runtime.state.ttl.TtlAggregateFunction;
import org.apache.flink.runtime.state.ttl.TtlAggregatingState;
import org.apache.flink.runtime.state.ttl.TtlFoldFunction;
import org.apache.flink.runtime.state.ttl.TtlFoldingState;
import org.apache.flink.runtime.state.ttl.TtlListState;
import org.apache.flink.runtime.state.ttl.TtlMapState;
import org.apache.flink.runtime.state.ttl.TtlReduceFunction;
import org.apache.flink.runtime.state.ttl.TtlReducingState;
import org.apache.flink.runtime.state.ttl.TtlStateSnapshotTransformer;
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
import org.apache.flink.runtime.state.ttl.TtlValue;
import org.apache.flink.runtime.state.ttl.TtlValueState;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.SupplierWithException;

/**
 * Creates keyed state for a {@link StateDescriptor} and, when the descriptor's
 * {@link StateTtlConfig} is enabled, wraps it into the matching TTL-aware state
 * ({@code TtlValueState}, {@code TtlListState}, ...).
 *
 * <p>The wrapped state is created through the supplied {@link KeyedStateFactory} from a
 * derived descriptor whose serializer is a {@link TtlSerializer}.
 *
 * @param <N> type of the namespace
 * @param <SV> value type of the user's state descriptor
 * @param <S> public state type described by the descriptor
 * @param <IS> internal state type handed back to the caller
 */
public class TtlStateFactory<N, SV, S extends State, IS extends S> {
    /** Dispatch table: descriptor class -> method that builds the TTL-wrapped state for it. */
    private final Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> stateFactories;
    private final TypeSerializer<N> namespaceSerializer;
    /** Descriptor supplied by the user (without TTL wrapping). */
    private final StateDescriptor<S, SV> stateDesc;
    /** Factory used to create the underlying state that gets wrapped. */
    private final KeyedStateFactory originalStateFactory;
    private final StateTtlConfig ttlConfig;
    private final TtlTimeProvider timeProvider;
    /** Time-to-live in milliseconds, taken from {@link #ttlConfig}. */
    private final long ttl;

    /**
     * Creates the state for {@code stateDesc}; if TTL is enabled in the descriptor's TTL config the
     * state is wrapped with the TTL logic, otherwise the state produced by
     * {@code originalStateFactory} is returned unchanged.
     *
     * @param namespaceSerializer serializer of the namespace, must not be null
     * @param stateDesc descriptor of the requested state, must not be null
     * @param originalStateFactory factory creating the underlying state, must not be null
     * @param timeProvider time source handed to the TTL wrappers, must not be null
     * @return the (possibly TTL-wrapped) internal state
     * @throws FlinkRuntimeException if TTL is enabled and the descriptor class is not supported
     * @throws Exception if creating the underlying state fails
     */
    public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled(
            TypeSerializer<N> namespaceSerializer,
            StateDescriptor<S, SV> stateDesc,
            KeyedStateFactory originalStateFactory,
            TtlTimeProvider timeProvider) throws Exception {
        Preconditions.checkNotNull(namespaceSerializer);
        Preconditions.checkNotNull(stateDesc);
        Preconditions.checkNotNull(originalStateFactory);
        Preconditions.checkNotNull(timeProvider);
        if (stateDesc.getTtlConfig().isEnabled()) {
            // A static method has no "super"; the TTL path needs its own factory instance.
            return new TtlStateFactory<N, SV, S, IS>(
                    namespaceSerializer, stateDesc, originalStateFactory, timeProvider).createState();
        }
        return originalStateFactory.createInternalState(namespaceSerializer, stateDesc);
    }

    private TtlStateFactory(
            TypeSerializer<N> namespaceSerializer,
            StateDescriptor<S, SV> stateDesc,
            KeyedStateFactory originalStateFactory,
            TtlTimeProvider timeProvider) {
        this.namespaceSerializer = namespaceSerializer;
        this.stateDesc = stateDesc;
        this.originalStateFactory = originalStateFactory;
        this.ttlConfig = stateDesc.getTtlConfig();
        this.timeProvider = timeProvider;
        this.ttl = this.ttlConfig.getTtl().toMilliseconds();
        // Built last: the suppliers are bound to this instance and read the fields above when invoked.
        this.stateFactories = this.createStateFactories();
    }

    /**
     * Builds the dispatch table from descriptor class to the factory method of this instance.
     *
     * <p>The explicit casts give each method reference a functional-interface target type;
     * {@code Tuple2.of} alone does not provide one.
     */
    @SuppressWarnings("rawtypes")
    private Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> createStateFactories() {
        return Stream.of(
                Tuple2.of(ValueStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createValueState),
                Tuple2.of(ListStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createListState),
                Tuple2.of(MapStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createMapState),
                Tuple2.of(ReducingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createReducingState),
                Tuple2.of(AggregatingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createAggregatingState),
                Tuple2.of(FoldingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createFoldingState)
        ).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
    }

    /**
     * Looks up the factory registered for the exact class of {@link #stateDesc} and invokes it.
     *
     * @throws FlinkRuntimeException if no factory is registered for the descriptor class
     */
    private IS createState() throws Exception {
        SupplierWithException<IS, Exception> stateFactory = this.stateFactories.get(this.stateDesc.getClass());
        if (stateFactory == null) {
            String message = String.format(
                    "State %s is not supported by %s", this.stateDesc.getClass(), TtlStateFactory.class);
            throw new FlinkRuntimeException(message);
        }
        return stateFactory.get();
    }

    // The factory methods below keep raw types: the generic signatures of the TTL wrappers are not
    // visible in this file, so the unchecked casts to IS are left as they are.

    /** Wraps a value state created from a descriptor that uses a {@link TtlSerializer}. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private IS createValueState() throws Exception {
        ValueStateDescriptor ttlDescriptor = new ValueStateDescriptor(
                this.stateDesc.getName(), new TtlSerializer(this.stateDesc.getSerializer()));
        return (IS) new TtlValueState(
                (InternalValueState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                this.stateDesc.getSerializer());
    }

    /** Wraps a list state; the TTL serializer is applied to each list element. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T> IS createListState() throws Exception {
        ListStateDescriptor listStateDesc = (ListStateDescriptor) this.stateDesc;
        ListStateDescriptor ttlDescriptor = new ListStateDescriptor(
                this.stateDesc.getName(), new TtlSerializer(listStateDesc.getElementSerializer()));
        return (IS) new TtlListState(
                (InternalListState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                listStateDesc.getSerializer());
    }

    /** Wraps a map state; the TTL serializer is applied to the map values, keys stay unchanged. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <UK, UV> IS createMapState() throws Exception {
        MapStateDescriptor mapStateDesc = (MapStateDescriptor) this.stateDesc;
        MapStateDescriptor ttlDescriptor = new MapStateDescriptor(
                this.stateDesc.getName(),
                mapStateDesc.getKeySerializer(),
                new TtlSerializer(mapStateDesc.getValueSerializer()));
        return (IS) new TtlMapState(
                (InternalMapState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                mapStateDesc.getSerializer());
    }

    /** Wraps a reducing state; the user's reduce function is decorated with {@link TtlReduceFunction}. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private IS createReducingState() throws Exception {
        ReducingStateDescriptor reducingStateDesc = (ReducingStateDescriptor) this.stateDesc;
        ReducingStateDescriptor ttlDescriptor = new ReducingStateDescriptor(
                this.stateDesc.getName(),
                new TtlReduceFunction(reducingStateDesc.getReduceFunction(), this.ttlConfig, this.timeProvider),
                new TtlSerializer(this.stateDesc.getSerializer()));
        return (IS) new TtlReducingState(
                (InternalReducingState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                this.stateDesc.getSerializer());
    }

    /**
     * Wraps an aggregating state; the same {@link TtlAggregateFunction} instance is given to both
     * the derived descriptor and the TTL wrapper.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <IN, OUT> IS createAggregatingState() throws Exception {
        AggregatingStateDescriptor aggregatingStateDescriptor = (AggregatingStateDescriptor) this.stateDesc;
        TtlAggregateFunction ttlAggregateFunction = new TtlAggregateFunction(
                aggregatingStateDescriptor.getAggregateFunction(), this.ttlConfig, this.timeProvider);
        AggregatingStateDescriptor ttlDescriptor = new AggregatingStateDescriptor(
                this.stateDesc.getName(), ttlAggregateFunction, new TtlSerializer(this.stateDesc.getSerializer()));
        return (IS) new TtlAggregatingState(
                (InternalAggregatingState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                this.stateDesc.getSerializer(),
                ttlAggregateFunction);
    }

    /**
     * Wraps a folding state. A non-null default accumulator is stored as a {@link TtlValue} stamped
     * with {@code Long.MAX_VALUE}; a null default stays null.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T> IS createFoldingState() throws Exception {
        FoldingStateDescriptor foldingStateDescriptor = (FoldingStateDescriptor) this.stateDesc;
        Object initAcc = this.stateDesc.getDefaultValue();
        TtlValue<Object> ttlInitAcc = initAcc == null ? null : new TtlValue<Object>(initAcc, Long.MAX_VALUE);
        FoldingStateDescriptor ttlDescriptor = new FoldingStateDescriptor(
                this.stateDesc.getName(),
                ttlInitAcc,
                new TtlFoldFunction(foldingStateDescriptor.getFoldFunction(), this.ttlConfig, this.timeProvider, initAcc),
                new TtlSerializer(this.stateDesc.getSerializer()));
        return (IS) new TtlFoldingState(
                (InternalFoldingState) this.originalStateFactory.createInternalState(
                        this.namespaceSerializer, ttlDescriptor, this.getSnapshotTransformFactory()),
                this.ttlConfig,
                this.timeProvider,
                this.stateDesc.getSerializer());
    }

    /**
     * Returns the snapshot transformer factory for the wrapped state: a TTL transformer factory when
     * the cleanup strategies report {@code inFullSnapshot()}, otherwise the no-op factory.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private StateSnapshotTransformer.StateSnapshotTransformFactory<?> getSnapshotTransformFactory() {
        if (!this.ttlConfig.getCleanupStrategies().inFullSnapshot()) {
            return StateSnapshotTransformer.StateSnapshotTransformFactory.noTransform();
        }
        return new TtlStateSnapshotTransformer.Factory(this.timeProvider, this.ttl);
    }

    /**
     * Serializer for {@link TtlValue}. It is composed of two field serializers in a fixed order:
     * index 0 is the last-access timestamp ({@link LongSerializer}), index 1 is the user value.
     *
     * @param <T> type of the user value
     */
    private static class TtlSerializer<T>
    extends CompositeSerializer<TtlValue<T>> {
        private static final long serialVersionUID = 131020282727167064L;

        TtlSerializer(TypeSerializer<T> userValueSerializer) {
            super(true, LongSerializer.INSTANCE, userValueSerializer);
        }

        TtlSerializer(CompositeSerializer.PrecomputedParameters precomputed, TypeSerializer<?> ... fieldSerializers) {
            super(precomputed, fieldSerializers);
        }

        /**
         * Builds a {@link TtlValue} from its two fields.
         *
         * @param values exactly two elements: the timestamp at index 0 and the user value at index 1
         * @throws IllegalArgumentException if {@code values} does not have exactly two elements
         *     (as raised by {@code Preconditions.checkArgument})
         */
        @SuppressWarnings("unchecked")
        public TtlValue<T> createInstance(Object ... values) {
            Preconditions.checkArgument(values.length == 2);
            // Unchecked: the element at index 1 is whatever the user value serializer produced.
            return new TtlValue<T>((T) values[1], (Long) values[0]);
        }

        /** Always throws, because {@link TtlValue} cannot be modified in place. */
        protected void setField(@Nonnull TtlValue<T> v, int index, Object fieldValue) {
            throw new UnsupportedOperationException("TtlValue is immutable");
        }

        /** Returns the timestamp for index 0 and the user value for any other index. */
        protected Object getField(@Nonnull TtlValue<T> v, int index) {
            return index == 0 ? Long.valueOf(v.getLastAccessTimestamp()) : v.getUserValue();
        }

        /**
         * Creates a new serializer instance from precomputed parameters and the two field
         * serializers. NOTE(review): presumably called by {@link CompositeSerializer} when it
         * duplicates itself; confirm against the base class.
         *
         * @param originalSerializers exactly two serializers in (timestamp, user value) order
         */
        protected CompositeSerializer<TtlValue<T>> createSerializerInstance(CompositeSerializer.PrecomputedParameters precomputed, TypeSerializer<?> ... originalSerializers) {
            Preconditions.checkNotNull(originalSerializers);
            Preconditions.checkArgument(originalSerializers.length == 2);
            // Forward BOTH serializers: passing only the user value serializer would yield an
            // instance with a single field serializer and break the (timestamp, value) layout.
            return new TtlSerializer<T>(precomputed, originalSerializers);
        }
    }
}

