import { z } from "zod";

import { PeriodZod } from "../../../models/primitives";
import env from "../../../services/env";

// Every algorithm identifier the schemas in this file recognise. Declared
// `as const` so the individual string literals are kept for the enum and the
// derived type below.
const allAlgorithmNames = [
  "Prophet",
  "SeasonalRegression",
  "ExponentialSmoothing",
  "Arima",
  "Nbeats",
  "Naive",
] as const;

// Runtime validator accepting exactly the names listed in `allAlgorithmNames`.
export const AlgorithmNameZod = z.enum(allAlgorithmNames);
// Union of the algorithm name literals, inferred from the validator.
export type AlgorithmName = z.infer<typeof AlgorithmNameZod>;

// Algorithms that are always part of `algorithmNames`.
const standardAlgorithms: AlgorithmName[] = [
  "Prophet",
  "SeasonalRegression",
  "ExponentialSmoothing",
  "Arima",
];

// Algorithms that are part of `algorithmNames` only when `env.GPU_AVAILABLE`
// is truthy.
const gpuAlgorithms: AlgorithmName[] = ["Nbeats"];

// Algorithms appended after the standard and GPU entries in `algorithmNames`.
const extraAlgorithms: AlgorithmName[] = ["Naive"];

// Exported list of algorithm names, built in this order: the standard set,
// then the GPU set (only when `env.GPU_AVAILABLE` is truthy), then the extras.
// `concat` returns a fresh array, so the source lists are never aliased.
export const algorithmNames: AlgorithmName[] = standardAlgorithms.concat(
  env.GPU_AVAILABLE ? gpuAlgorithms : [],
  extraAlgorithms,
);

// Read-only list of the algorithm names flagged as supporting "IFs".
// NOTE(review): "IFs" presumably stands for influencing factors (see the
// `...InfluencingFactors` fields on the algorithm config schema) — confirm.
export const algorithmsSupportingIFs: readonly AlgorithmName[] = [
  "Prophet",
  "SeasonalRegression",
] as const;

// Allowed values for the `applyPartialIfsWhen` field of an algorithm config.
export const partialIFsRules = [
  "Always",
  "Never",
  "50%",
  "Proportional",
] as const;
// Runtime validator accepting exactly the values in `partialIFsRules`.
export const PartialIFsRuleZod = z.enum(partialIFsRules);
// Union of the rule literals, inferred from the validator.
export type PartialIFsRule = z.infer<typeof PartialIFsRuleZod>;

// Reusable pieces of the Prophet schema. zod schemas are immutable, so one
// instance can safely back several fields.
// - a boolean or the literal "auto", defaulting to "auto"
const prophetAutoToggle = z.boolean().or(z.literal("auto")).default("auto");
// - a number within [0, 1]; each field supplies its own default
const prophetUnitInterval = z.number().min(0).max(1);

// Parameters accepted by the Prophet algorithm.
// all params must have default values
export const ProphetParametersZod = z.object({
  trend: z.boolean().default(true),
  weekly_seasonality: prophetAutoToggle,
  yearly_seasonality: prophetAutoToggle,
  changepoint_range: prophetUnitInterval.default(0.8),
  exclude_non_significant_factors: z.boolean().default(true),
  max_pval_inclusion: prophetUnitInterval.default(0.05),
});
// Parsed (output) shape of the Prophet parameters.
export type ProphetParameters = z.infer<typeof ProphetParametersZod>;

// Named seasonality levels accepted by the seasonal regression `seasonality`
// parameter (which also accepts a plain boolean).
export const seasonalityTypes = ["Weak", "Medium", "Strong"] as const;
// Runtime validator for the values in `seasonalityTypes`; module-private.
const SeasonalityTypeZod = z.enum(seasonalityTypes);
// Union of the seasonality literals, inferred from the validator.
export type SeasonalityType = z.infer<typeof SeasonalityTypeZod>;

// Boolean switches shared by the seasonal regression fields. zod schemas are
// immutable, so reusing one instance across keys does not couple the fields.
const regressionSwitchOn = z.boolean().default(true);
const regressionSwitchOff = z.boolean().default(false);

// Parameters accepted by the SeasonalRegression algorithm.
// all params must have default values
export const SeasonalRegressionParametersZod = z.object({
  exclude_non_significant_factors: regressionSwitchOn,
  max_pval_inclusion: z.number().min(0).max(1).default(0.1),
  intersect: regressionSwitchOn,
  linear_trend: regressionSwitchOff,
  quadratic_trend: regressionSwitchOff,
  logarithmic_trend: regressionSwitchOff,
  // Either one of the named seasonality levels or a plain boolean.
  seasonality: SeasonalityTypeZod.or(z.boolean()).default(false),
  day_in_year: regressionSwitchOff,
  day_in_month: regressionSwitchOff,
  day_in_week: regressionSwitchOff,
  month_in_year: regressionSwitchOff,
  quarters: regressionSwitchOff,
  year: regressionSwitchOff,
  previous_entry: regressionSwitchOff,
  previous_day: regressionSwitchOff,
  previous_week: regressionSwitchOff,
  previous_month: regressionSwitchOff,
  previous_quarter: regressionSwitchOff,
  previous_year: regressionSwitchOff,
});
// Parsed (output) shape of the SeasonalRegression parameters.
export type SeasonalRegressionParameters = z.infer<
  typeof SeasonalRegressionParametersZod
>;

// A boolean or the literal "auto", defaulting to "auto". One immutable schema
// instance backs every tri-state field of the exponential smoothing schema.
const smoothingAutoToggle = z.boolean().or(z.literal("auto")).default("auto");

// Parameters accepted by the ExponentialSmoothing algorithm.
// all params must have default values
export const ExponentialSmoothingParametersZod = z.object({
  // A number within [0, 1], or null; null is the default.
  alpha: z.number().min(0).max(1).nullable().default(null),
  trend: smoothingAutoToggle,
  damped_trend: smoothingAutoToggle,
  seasonal: smoothingAutoToggle,
  // Either a structured period or a string; defaults to length 7, unit "D".
  period_length: PeriodZod.or(z.string()).default({ length: 7, unit: "D" }),
});
// Parsed (output) shape of the ExponentialSmoothing parameters.
export type ExponentialSmoothingParameters = z.infer<
  typeof ExponentialSmoothingParametersZod
>;

// Schema for a single manually supplied ARIMA order (used for p, d, q and the
// seasonal sp, sd, sq fields).
//
// Accepts an integer between 0 and 5 inclusive, or null. A missing value
// becomes null (`default`), and any value that fails validation is replaced
// by null instead of raising an error (`catch`).
//
// ARIMA orders are counts, so `.int()` rejects fractional input such as 1.5;
// it then falls back to null exactly like an out-of-range value does.
// NOTE(review): null presumably means "let the algorithm choose" — confirm
// against the consumer of these parameters.
const ArimaManualParameterZod = z
  .number()
  .int()
  .min(0)
  .max(5)
  .nullable()
  .default(null)
  .catch(null);

// Parameters accepted by the Arima algorithm.
// all params must have default values
export const ArimaParametersZod = z.object({
  seasonal: z.boolean().default(false),
  // Either a structured period or a string; defaults to length 7, unit "D".
  period_length: PeriodZod.or(z.string()).default({ length: 7, unit: "D" }),
  // Manual orders; each one is validated by `ArimaManualParameterZod`.
  p: ArimaManualParameterZod,
  d: ArimaManualParameterZod,
  q: ArimaManualParameterZod,
  sp: ArimaManualParameterZod,
  sd: ArimaManualParameterZod,
  sq: ArimaManualParameterZod,
});
// Parsed (output) shape of the Arima parameters.
export type ArimaParameters = z.infer<typeof ArimaParametersZod>;

// Parameters for the Nbeats algorithm: the schema declares no keys.
export const NbeatsParametersZod = z.object({});
// Parsed (output) shape of the Nbeats parameters (an empty object type).
export type NbeatsParameters = z.infer<typeof NbeatsParametersZod>;

// Model identifiers accepted in the Naive algorithm's `model` parameter.
export const naiveModels = [
  "mean",
  "drift",
  "last_value",
  "daily",
  "weekly",
  "monthly",
  "quarterly",
  "yearly",
] as const;
// Runtime validator accepting exactly the values in `naiveModels`.
export const NaiveModelZod = z.enum(naiveModels);
// Union of the naive model literals, inferred from the validator.
export type NaiveModel = z.infer<typeof NaiveModelZod>;

// Parameters accepted by the Naive algorithm.
// all params must have default values
export const NaiveParametersZod = z.object({
  // A list of naive models; defaults to a single "mean" entry when omitted.
  model: z.array(NaiveModelZod).default(["mean"]),
});
// Parsed (output) shape of the Naive parameters.
export type NaiveParameters = z.infer<typeof NaiveParametersZod>;

// Union of the parsed parameter shapes of the individual algorithms.
// NOTE(review): `NbeatsParameters` is not a member of this union. It is
// inferred from an empty object schema, so adding it would make the union
// accept almost any object — confirm the omission is intentional.
export type AlgorithmParameters =
  | ProphetParameters
  | SeasonalRegressionParameters
  | ExponentialSmoothingParameters
  | ArimaParameters
  | NaiveParameters;

// Fields common to every algorithm configuration. The per-algorithm variants
// of `AlgorithmConfigZod` extend this schema, narrowing `algorithmName` to a
// single literal and adding the matching `algorithmParameters` schema.
// Fields declared with `.default(...)` may be omitted from the input.
const AlgorithmBetaZod = z.object({
  algorithmConfigId: z.number(),
  algorithmName: AlgorithmNameZod,
  algorithmLabel: z.string(),
  algorithmDescription: z.string(),
  pipelinePosition: z.number(),
  excludeScaledDataCollectionInfluencingFactors: z.boolean(),
  splitWeekdays: z.boolean().default(false),
  applyPartialIfsWhen: PartialIFsRuleZod.default("Proportional"),
  excludePartitionInfluencingFactors: z.boolean().default(false),
  excludeAllInfluencingFactors: z.boolean().default(false),
  useBusinessHourInfluencingFactor: z.boolean().default(false),
  covariateMeasurements: z.number().array().default([]),
});

// One configuration schema per algorithm. Each takes the shared fields from
// `AlgorithmBetaZod`, pins `algorithmName` to a single literal and attaches
// the parameter schema that belongs to that algorithm.
const ProphetConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("Prophet"),
  algorithmParameters: ProphetParametersZod,
});
const SeasonalRegressionConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("SeasonalRegression"),
  algorithmParameters: SeasonalRegressionParametersZod,
});
const ExponentialSmoothingConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("ExponentialSmoothing"),
  algorithmParameters: ExponentialSmoothingParametersZod,
});
const ArimaConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("Arima"),
  algorithmParameters: ArimaParametersZod,
});
const NbeatsConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("Nbeats"),
  algorithmParameters: NbeatsParametersZod,
});
const NaiveConfigZod = AlgorithmBetaZod.extend({
  algorithmName: z.literal("Naive"),
  algorithmParameters: NaiveParametersZod,
});

// Algorithm configuration, discriminated on `algorithmName`: the name selects
// which variant (and therefore which parameter schema) validates the input.
export const AlgorithmConfigZod = z.discriminatedUnion("algorithmName", [
  ProphetConfigZod,
  SeasonalRegressionConfigZod,
  ExponentialSmoothingConfigZod,
  ArimaConfigZod,
  NbeatsConfigZod,
  NaiveConfigZod,
]);
// Parsed (output) shape of an algorithm configuration.
export type AlgorithmConfig = z.infer<typeof AlgorithmConfigZod>;

// Shared fields of a frozen algorithm configuration: identical to
// `AlgorithmBetaZod` except that `algorithmConfigId` is removed and a
// required numeric `frozenAlgorithmConfigId` is added in its place.
const FrozenAlgoBetaZod = AlgorithmBetaZod.omit({
  algorithmConfigId: true,
}).extend({
  frozenAlgorithmConfigId: z.number(),
});
// One frozen configuration schema per algorithm. Each takes the shared fields
// from `FrozenAlgoBetaZod`, pins `algorithmName` to a single literal and
// attaches the parameter schema that belongs to that algorithm.
const FrozenProphetConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("Prophet"),
  algorithmParameters: ProphetParametersZod,
});
const FrozenSeasonalRegressionConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("SeasonalRegression"),
  algorithmParameters: SeasonalRegressionParametersZod,
});
const FrozenExponentialSmoothingConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("ExponentialSmoothing"),
  algorithmParameters: ExponentialSmoothingParametersZod,
});
const FrozenArimaConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("Arima"),
  algorithmParameters: ArimaParametersZod,
});
const FrozenNbeatsConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("Nbeats"),
  algorithmParameters: NbeatsParametersZod,
});
const FrozenNaiveConfigZod = FrozenAlgoBetaZod.extend({
  algorithmName: z.literal("Naive"),
  algorithmParameters: NaiveParametersZod,
});

// Frozen algorithm configuration, discriminated on `algorithmName`: the name
// selects which variant (and therefore which parameter schema) applies.
export const FrozenAlgorithmConfigZod = z.discriminatedUnion("algorithmName", [
  FrozenProphetConfigZod,
  FrozenSeasonalRegressionConfigZod,
  FrozenExponentialSmoothingConfigZod,
  FrozenArimaConfigZod,
  FrozenNbeatsConfigZod,
  FrozenNaiveConfigZod,
]);
// Parsed (output) shape of a frozen algorithm configuration.
export type FrozenAlgorithmConfig = z.infer<typeof FrozenAlgorithmConfigZod>;
