Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/conv/solver_finders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ class DirectSolverFinder : public SolversFinderMixin<ProblemDescription, ConvFin
std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const ConvFindParameters&) const override
const ConvFindParameters&,
const std::optional<FindOptions>&) const override
{
/// \todo: actually use FindOptions
return problem.GetDirection() != conv::Direction::BackwardWeights
? FindAllDirectSolutions(ctx, problem, invoke_ctx)
: FindAllBwdWrW2DSolutions(ctx, problem, invoke_ctx);
Expand All @@ -91,8 +93,10 @@ class ImplicitGemmSolverFinder : public SolversFinderMixin<ProblemDescription, C
std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const ConvFindParameters&) const override
const ConvFindParameters&,
const std::optional<FindOptions>&) const override
{
/// \todo: actually use FindOptions
return problem.GetDirection() != conv::Direction::BackwardWeights
? FindAllImplicitGemmSolutions(ctx, problem, invoke_ctx)
: FindImplicitGemmWrWAllSolutions(ctx, problem, invoke_ctx);
Expand Down Expand Up @@ -120,8 +124,10 @@ class FftSolverFinder : public SolversFinderMixin<ProblemDescription, ConvFindPa
std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const ConvFindParameters&) const override
const ConvFindParameters&,
const std::optional<FindOptions>&) const override
{
/// \todo: actually use FindOptions
return FindAllFFTSolutions(ctx, problem, invoke_ctx);
}
};
Expand All @@ -145,8 +151,10 @@ class GemmSolverFinder : public SolversFinderMixin<ProblemDescription, ConvFindP
std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const ConvFindParameters&) const override
const ConvFindParameters&,
const std::optional<FindOptions>&) const override
{
/// \todo: actually use FindOptions
return FindAllGemmSolutions(ctx, problem, invoke_ctx);
}
};
Expand All @@ -170,8 +178,10 @@ class WinogradSolverFinder : public SolversFinderMixin<ProblemDescription, ConvF
std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const ConvFindParameters& parameters) const override
const ConvFindParameters& parameters,
const std::optional<FindOptions>&) const override
{
/// \todo: actually use FindOptions
auto ctx_copy = ctx;
if(parameters.use_winograd_only)
ctx_copy.use_dynamic_solutions_only = true;
Expand Down Expand Up @@ -283,7 +293,8 @@ void FindCore(const AnyInvokeParams& invoke_ctx,
const ExecutionContext& ctx,
const ProblemDescriptionBase& problem,
const PrimitiveFindParameters& parameters,
const std::vector<std::unique_ptr<ISolversFinder>>& finders)
const std::vector<std::unique_ptr<ISolversFinder>>& finders,
const std::optional<FindOptions>& options)
{
auto& handle = ctx.GetStream();

Expand All @@ -292,7 +303,7 @@ void FindCore(const AnyInvokeParams& invoke_ctx,
std::transform(
finders.begin(), finders.end(), std::inserter(solutions, solutions.end()), [&](auto&& f) {
return std::make_pair(f->GetAlgorithmName(problem),
f->Find(ctx, problem, invoke_ctx, parameters));
f->Find(ctx, problem, invoke_ctx, parameters, options));
});

// Precompile
Expand Down
100 changes: 69 additions & 31 deletions src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,12 @@ miopenStatus_t ConvBiasActivFusion(Handle& handle,
}

static auto
AllocateBuffersAndMakeFusionInvokeParams(const FusionContext& context,
AllocateBuffersAndMakeFusionInvokeParams(Handle& handle,
const FusionDescription& problem,
std::vector<Allocator::ManageDataPtr>& invoke_bufs,
miopen::OperatorArgs& params,
const FusionPlanDescriptor& plan)
{
auto& handle = context.GetStream();

const auto allocate_buffer = [&](std::size_t size) {
auto ptr = handle.Create(size);
auto ret = ptr.get();
Expand Down Expand Up @@ -709,6 +707,29 @@ static auto GetFusedWinogradSolvers()
solver::fusion::ConvBinWinogradRxSf2x3g1Fused>{};
}

static auto GetAllFusionSolvers()
{
return GetFusedNonConvSolvers() + GetFusedDirectSolvers() + GetFusedIGemmSolvers() +
GetFusedWinogradSolvers();
}

solver::ConvSolution MakeFusedSolution(const FusionContext& ctx,
solver::Id id,
const std::optional<std::string>& perf_cfg_override,
const FusionDescription& problem,
const AnyInvokeParams& invoke_params)
{
decltype(auto) db = GetDb(ctx);
solver::ConvSolution solution{miopenStatusInternalError};

GetAllFusionSolvers().FindById(id, [&](auto solver) {
solution = miopen::solver::FindSolution(
solver, ctx, problem, db, invoke_params, perf_cfg_override.value_or(""));
});

return solution;
}

struct FusionFindParameters : PrimitiveFindParameters
{
};
Expand All @@ -732,13 +753,19 @@ class FusionSolverFinder : public SolversFinderMixin<FusionDescription, FusionFi
return true;
}

std::vector<solver::ConvSolution> FindImpl(const ExecutionContext& ctx,
const FusionDescription& problem,
const AnyInvokeParams& invoke_ctx,
const FusionFindParameters&) const override
std::vector<solver::ConvSolution>
FindImpl(const ExecutionContext& ctx,
const FusionDescription& problem,
const AnyInvokeParams& invoke_ctx,
const FusionFindParameters&,
const std::optional<FindOptions>& options) const override
{
return solvers.SearchForAllSolutions(
dynamic_cast<const FusionContext&>(ctx), problem, miopen::GetDb(ctx), invoke_ctx);
return solvers.SearchForAllSolutions(dynamic_cast<const FusionContext&>(ctx),
problem,
miopen::GetDb(ctx),
invoke_ctx,
std::numeric_limits<std::size_t>::max(),
options);
}

private:
Expand All @@ -763,47 +790,50 @@ static const std::vector<std::unique_ptr<ISolversFinder>>& GetFusionSolverFinder
return finders;
}

miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
static std::vector<PerfField>
FindFusion(const ExecutionContext& ctx,
const FusionDescription& fusion_problem,
const std::function<fusion::FusionInvokeParams()>& invoke_params,
const std::optional<FindOptions>& options = std::nullopt)
{
auto fusion_ctx = FusionContext{handle};
auto fusion_problem = FusionDescription{this};
const FindEnforce enforce;

// sols is a collection of ConvSolutions that have been returned from Find for the
// fusion_problem. These ConvSolutions store instructions on how to build kernels and an invoker
// factory.
std::vector<miopen::solver::ConvSolution> sols;

auto find_results = UserFindDbRecord::TryLoad(
handle,
return UserFindDbRecord::TryLoad(
ctx.GetStream(),
fusion_problem,
[&](DbRecord& record) {
// fusion_ctx.use_dynamic_solutions_only = findMode.IsDynamicHybrid(fusion_ctx);

// We need buffers for find, thus we allocate them.
miopen::OperatorArgs params;
std::vector<Allocator::ManageDataPtr> invoke_bufs;
const auto invoke_params = AllocateBuffersAndMakeFusionInvokeParams(
fusion_ctx, fusion_problem, invoke_bufs, params, *this);

FindCore(invoke_params,
// We need buffers for find, thus we lazily get them, possibly allocating.
auto fusion_ctx = FusionContext(ctx.GetStream());
FindCore(invoke_params(),
record,
fusion_ctx,
fusion_problem,
FusionFindParameters{},
GetFusionSolverFinders());
GetFusionSolverFinders(),
options);
},
"fusion");
}

miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
{
std::vector<Allocator::ManageDataPtr> invoke_bufs;
miopen::OperatorArgs params;

const auto find_results = Find(handle, [&]() {
return AllocateBuffersAndMakeFusionInvokeParams(
handle, FusionDescription{this}, invoke_bufs, params, *this);
});

const auto network_config = fusion_problem.MakeNetworkConfig();
const auto network_config = FusionDescription{this}.MakeNetworkConfig();

for(const auto& result : find_results)
{
if(conv_fwd_algo && result.algorithm != "fusion" &&
miopen::StringToConvolutionFwdAlgo(result.algorithm) != *conv_fwd_algo)
continue;
const auto id = solver::Id{result.solver_id};

const auto id = solver::Id{result.solver_id};
const auto invoker = handle.GetInvoker(network_config, id);

if(!invoker)
Expand All @@ -825,6 +855,14 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
return miopenStatusSuccess;
}

std::vector<struct PerfField>
FusionPlanDescriptor::Find(Handle& handle,
const std::function<fusion::FusionInvokeParams()>& invoke_params,
const std::optional<FindOptions>& options) const
{
return FindFusion(&handle, this, invoke_params, options);
}

miopenStatus_t FusionPlanDescriptor::Execute(const Handle& handle,
const TensorDescriptor& inputDesc,
ConstData_t input,
Expand Down
18 changes: 18 additions & 0 deletions src/generic_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,27 @@

namespace miopen {
namespace solver {
namespace debug {

// NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables)
static std::optional<std::size_t> tuning_iterations_limit;

TuningIterationScopedLimiter::TuningIterationScopedLimiter(std::size_t new_limit)
: old_limit(tuning_iterations_limit)
{
tuning_iterations_limit = new_limit;
}

TuningIterationScopedLimiter::~TuningIterationScopedLimiter()
{
tuning_iterations_limit = old_limit;
}
} // namespace debug

std::size_t GetTuningIterationsMax()
{
if(debug::tuning_iterations_limit)
return *debug::tuning_iterations_limit;
return Value(MIOPEN_DEBUG_TUNING_ITERATIONS_MAX{}, std::numeric_limits<std::size_t>::max());
}

Expand Down
23 changes: 15 additions & 8 deletions src/include/miopen/conv/solver_finders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
#pragma once

#include <miopen/conv_solution.hpp>
#include <miopen/errors.hpp>
#include <miopen/execution_context.hpp>
#include <miopen/problem_description_base.hpp>
#include <miopen/errors.hpp>
#include <miopen/search_options.hpp>

#include <memory>
#include <type_traits>
Expand Down Expand Up @@ -58,7 +59,8 @@ class ISolversFinder
Find(const ExecutionContext& ctx,
const ProblemDescriptionBase& problem,
const AnyInvokeParams& invoke_ctx,
const PrimitiveFindParameters& parameters) const
const PrimitiveFindParameters& parameters,
const std::optional<FindOptions>& find_options) const
{
if(!IsEnabled(ctx, problem, parameters))
{
Expand All @@ -69,7 +71,7 @@ class ISolversFinder
try
{
MIOPEN_LOG_I2("Starting find for " << GetAlgorithmName(problem).ToString());
return FindImpl(ctx, problem, invoke_ctx, parameters);
return FindImpl(ctx, problem, invoke_ctx, parameters, find_options);
}
catch(Exception& ex)
{
Expand All @@ -86,7 +88,8 @@ class ISolversFinder
FindImpl(const ExecutionContext& ctx,
const ProblemDescriptionBase& problem,
const AnyInvokeParams& invoke_ctx,
const PrimitiveFindParameters& parameters) const = 0;
const PrimitiveFindParameters& parameters,
const std::optional<FindOptions>& options) const = 0;
};

template <class ProblemDescription, class FindParameters>
Expand All @@ -105,12 +108,14 @@ class SolversFinderMixin : public ISolversFinder
FindImpl(const ExecutionContext& ctx,
const ProblemDescriptionBase& problem,
const AnyInvokeParams& invoke_ctx,
const PrimitiveFindParameters& parameters) const final
const PrimitiveFindParameters& parameters,
const std::optional<FindOptions>& options) const final
{
return FindImpl(ctx,
static_cast<const ProblemDescription&>(problem),
invoke_ctx,
static_cast<const FindParameters&>(parameters));
static_cast<const FindParameters&>(parameters),
options);
}

[[nodiscard]] bool IsEnabled(const ExecutionContext& ctx,
Expand All @@ -130,7 +135,8 @@ class SolversFinderMixin : public ISolversFinder
FindImpl(const ExecutionContext& ctx,
const ProblemDescription& problem,
const AnyInvokeParams& invoke_ctx,
const FindParameters& parameters) const = 0;
const FindParameters& parameters,
const std::optional<FindOptions>& options) const = 0;

[[nodiscard]] virtual bool IsEnabled(const ExecutionContext& ctx,
const ProblemDescription& problem,
Expand All @@ -148,7 +154,8 @@ void FindCore(const AnyInvokeParams& invoke_ctx,
const ExecutionContext& ctx,
const ProblemDescriptionBase& problem,
const PrimitiveFindParameters& parameters,
const std::vector<std::unique_ptr<ISolversFinder>>& finders);
const std::vector<std::unique_ptr<ISolversFinder>>& finders,
const std::optional<FindOptions>& options = std::nullopt);

namespace conv {

Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/find_controls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class FindEnforce

public:
FindEnforce();
explicit FindEnforce(FindEnforceAction action_) : action(action_) {}

template <class Context>
bool IsDbClean(const Context& context) const
Expand Down
Loading