package impl

import (
	"testing"

	"github.com/golang/protobuf/proto"
	"github.com/stretchr/testify/assert"

	"github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models"
	"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common"
)

// Fixtures shared by the tests in this file.
var (
	// launchPlanIdentifier is the identifier handed to the provider under test.
	launchPlanIdentifier = &core.Identifier{
		ResourceType: core.ResourceType_LAUNCH_PLAN,
		Project:      "project",
		Domain:       "domain",
		Name:         "name",
		Version:      "version",
	}

	// inputs declares a single STRING parameter "foo" whose default literal is "foo-value".
	inputs = core.ParameterMap{
		Parameters: map[string]*core.Parameter{
			"foo": {
				Var: &core.Variable{
					Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}},
				},
				Behavior: &core.Parameter_Default{
					Default: coreutils.MustMakeLiteral("foo-value"),
				},
			},
		},
	}

	// outputs declares a single STRING output variable "foo".
	outputs = core.VariableMap{
		Variables: map[string]*core.Variable{
			"foo": {
				Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}},
			},
		},
	}
)

// getProviderForTest builds a LaunchPlanInterfaceProvider from a launch plan
// model whose serialized closure declares the package-level inputs and
// outputs fixtures, using launchPlanIdentifier as the provider's identifier.
// It fails the calling test immediately if the closure cannot be marshaled
// or the provider cannot be constructed.
func getProviderForTest(t *testing.T) common.InterfaceProvider {
	t.Helper()
	closure := &admin.LaunchPlanClosure{
		ExpectedInputs:  &inputs,
		ExpectedOutputs: &outputs,
	}
	// Do not ignore the marshal error: feeding empty/partial bytes to the
	// provider would surface later as a confusing, unrelated failure.
	closureBytes, err := proto.Marshal(closure)
	if err != nil {
		t.Fatalf("Failed to marshal LaunchPlanClosure for test with err %v", err)
	}
	provider, err := NewLaunchPlanInterfaceProvider(
		models.LaunchPlan{
			Closure: closureBytes,
		}, launchPlanIdentifier)
	if err != nil {
		t.Fatalf("Failed to initialize LaunchPlanInterfaceProvider for test with err %v", err)
	}
	return provider
}

// TestGetId verifies that the provider reports the launch plan identifier it
// was constructed with.
func TestGetId(t *testing.T) {
	provider := getProviderForTest(t)
	// Use the named enum constant rather than its raw numeric value (3) so the
	// expectation stays readable and matches the fixture's declaration.
	expected := &core.Identifier{
		ResourceType: core.ResourceType_LAUNCH_PLAN,
		Project:      "project",
		Domain:       "domain",
		Name:         "name",
		Version:      "version",
	}
	assert.Equal(t, expected, provider.GetID())
}

// TestGetExpectedInputs verifies that the provider exposes the "foo" input
// parameter from the launch plan closure with its STRING type and its
// default literal intact.
func TestGetExpectedInputs(t *testing.T) {
	provider := getProviderForTest(t)
	// Fetch the parameters once through the nil-safe getters; dereferencing
	// the returned pointer would panic on nil instead of failing an assertion.
	parameters := provider.GetExpectedInputs().GetParameters()
	if !assert.Contains(t, parameters, "foo") {
		return
	}
	foo := parameters["foo"]
	// SimpleType is an enum value, so a NotNil check on it can never fail;
	// compare the enum directly instead.
	assert.Equal(t, core.SimpleType_STRING, foo.GetVar().GetType().GetSimple())
	assert.NotNil(t, foo.GetDefault())
	// The default literal should survive the marshal/unmarshal round trip.
	assert.True(t, proto.Equal(inputs.GetParameters()["foo"].GetDefault(), foo.GetDefault()),
		"default literal for \"foo\" did not round-trip through the launch plan closure")
}

// TestGetExpectedOutputs checks that the type of the "foo" output exposed by
// the provider matches the type declared in the outputs fixture.
func TestGetExpectedOutputs(t *testing.T) {
	provider := getProviderForTest(t)
	want := outputs.GetVariables()["foo"].GetType().GetType()
	got := provider.GetExpectedOutputs().GetVariables()["foo"].GetType().GetType()
	assert.EqualValues(t, want, got)
}
