// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

//
// This module defines the MatMulNBits operator. It is basically a float
// matmul whose right-hand side is a 2-D matrix that has been pre-packed
// and block-compacted into int4.
//
#pragma once
#include "core/common/safeint.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;

template <typename T>
class MatMulNBits final : public RocmKernel {
 public:
  MatMulNBits(const OpKernelInfo& info) : RocmKernel(info) {
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("bits", &nbits_));
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int64_t K_;
  int64_t N_;
  int64_t block_size_;
  int64_t nbits_;
  bool column_wise_quant_blk_{true};
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
