🌐 AI搜索 & 代理 主页
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/optseq/src/Main.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package example.optseq
import mainargs.{main, arg, ParserForMethods, ArgReader}
import mainargs.{main, arg, ParserForMethods, TokensReader}

object Main {
@main
Expand Down
125 changes: 96 additions & 29 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,34 @@ object Macros {
val companionModuleType = typeSymbolOfB.companionModule.tree.asInstanceOf[ValDef].tpt.tpe.asType
val companionModuleExpr = Ident(companionModule).asExpr
val mainAnnotationInstance = typeSymbolOfB.getAnnotation(mainAnnotation).getOrElse {
report.throwError(
s"cannot find @main annotation on ${companionModule.name}",
typeSymbolOfB.pos.get
'{new mainargs.main()}.asTerm // construct a default if not found.
}
val ctor = typeSymbolOfB.primaryConstructor
val ctorParams = ctor.paramSymss.flatten
// try to match the apply method with the constructor parameters, this is a good heuristic
// for if the apply method is overloaded.
val annotatedMethod = typeSymbolOfB.companionModule.memberMethod("apply").filter(p =>
p.paramSymss.flatten.corresponds(ctorParams) { (p1, p2) =>
p1.name == p2.name
}
).headOption.getOrElse {
report.errorAndAbort(
s"Cannot find apply method in companion object of ${typeReprOfB.show}",
typeSymbolOfB.companionModule.pos.getOrElse(Position.ofMacroExpansion)
)
}
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
companionModuleType match
case '[bCompanion] =>
val mainData = createMainData[B, Any](
val mainData = createMainData[B, bCompanion](
annotatedMethod,
mainAnnotationInstance,
// Somehow the `apply` method parameter annotations don't end up on
// the `apply` method parameters, but end up in the `<init>` method
// parameters, so use those for getting the annotations instead
TypeRepr.of[B].typeSymbol.primaryConstructor.paramSymss
)
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
val erasedMainData = '{$mainData.asInstanceOf[MainData[B, Any]]}
'{ new ParserForClass[B]($erasedMainData, () => ${ Ident(companionModule).asExpr }) }
}

def createMainData[T: Type, B: Type](using Quotes)
Expand All @@ -57,41 +68,84 @@ object Macros {
createMainData[T, B](method, mainAnnotation, method.paramSymss)
}

private object VarargTypeRepr {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Option[quotes.reflect.TypeRepr] = {
import quotes.reflect.*
tpe match {
case AnnotatedType(AppliedType(_, Seq(arg)), x)
if x.tpe =:= defn.RepeatedAnnot.typeRef => Some(arg)
case _ => None
}
}
}

private object AsType {
def unapply(using Quotes)(tpe: quotes.reflect.TypeRepr): Some[Type[?]] = {
Some(tpe.asType)
}
}

def createMainData[T: Type, B: Type](using Quotes)
(method: quotes.reflect.Symbol,
mainAnnotation: quotes.reflect.Term,
annotatedParamsLists: List[List[quotes.reflect.Symbol]]): Expr[MainData[T, B]] = {

import quotes.reflect.*
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
val defaultParams = getDefaultParams(method)
val defaultParams = if (params.exists(_.flags.is(Flags.HasDefault))) getDefaultParams(method) else Map.empty
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
val param = paramAndAnnotParam._1
val annotParam = paramAndAnnotParam._2
val paramTree = param.tree.asInstanceOf[ValDef]
val paramTpe = paramTree.tpt.tpe
val readerTpe = paramTpe match {
case VarargTypeRepr(AsType('[t])) => TypeRepr.of[Leftover[t]]
case _ => paramTpe
}
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
val paramType = paramTpe.asType
paramType match
readerTpe.asType match {
case '[t] =>
def applyAndCast(f: Expr[Any] => Expr[Any], arg: Expr[B]): Expr[t] = {
f(arg) match {
case '{ $v: `t` } => v
case expr => {
// this case will be activated when the found default parameter is not of type `t`
val recoveredType =
try
expr.asExprOf[t]
catch
case err: Exception =>
report.errorAndAbort(
s"""Failed to convert default value for parameter ${param.name},
|expected type: ${paramTpe.show},
|but default value ${expr.show} is of type: ${expr.asTerm.tpe.widen.show}
|while converting type caught an exception with message: ${err.getMessage}
|There might be a bug in mainargs.""".stripMargin,
param.pos.getOrElse(Position.ofMacroExpansion)
)
recoveredType
}
}
}
val defaultParam: Expr[Option[B => t]] = defaultParams.get(param) match {
case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) }
case Some(f) => '{ Some((b: B) => ${ applyAndCast(f, 'b) }) }
case None => '{ None }
}
val tokensReader = Expr.summon[mainargs.TokensReader[t]].getOrElse {
report.throwError(
s"No mainargs.ArgReader found for parameter ${param.name}",
param.pos.get
report.errorAndAbort(
s"No mainargs.TokensReader[${Type.show[t]}] found for parameter ${param.name} of method ${method.name} in ${method.owner.fullName}",
method.pos.getOrElse(Position.ofMacroExpansion)
)
}
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
}
}
val argSigs = Expr.ofList(argSigsExprs)

val invokeRaw: Expr[(B, Seq[Any]) => T] = {

def callOf(methodOwner: Expr[Any], args: Expr[Seq[Any]]) =
call(methodOwner, method, '{ Seq($args) }).asExprOf[T]
call(methodOwner, method, args).asExprOf[T]

'{ (b: B, params: Seq[Any]) => ${ callOf('b, 'params) } }
}
Expand Down Expand Up @@ -120,37 +174,50 @@ object Macros {
private def call(using Quotes)(
methodOwner: Expr[Any],
method: quotes.reflect.Symbol,
argss: Expr[Seq[Seq[Any]]]
args: Expr[Seq[Any]]
): Expr[_] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L106
import quotes.reflect._
val paramss = method.paramSymss

if (paramss.isEmpty) {
report.throwError("At least one parameter list must be declared.", method.pos.get)
report.errorAndAbort("At least one parameter list must be declared.", method.pos.get)
}

val accesses: List[List[Term]] = for (i <- paramss.indices.toList) yield {
for (j <- paramss(i).indices.toList) yield {
val tpe = paramss(i)(j).tree.asInstanceOf[ValDef].tpt.tpe
tpe.asType match
case '[t] => '{ $argss(${Expr(i)})(${Expr(j)}).asInstanceOf[t] }.asTerm
}
if (paramss.sizeIs > 1) {
report.errorAndAbort("Multiple parameter lists are not supported.", method.pos.get)
}
val params = paramss.head

val methodType = methodOwner.asTerm.tpe.memberType(method)

def accesses(ref: Expr[Seq[Any]]): List[Term] =
for (i <- params.indices.toList) yield {
val param = params(i)
val tpe = methodType.memberType(param)
val untypedRef = '{ $ref(${Expr(i)}) }
tpe match {
case VarargTypeRepr(AsType('[t])) =>
Typed(
'{ $untypedRef.asInstanceOf[Leftover[t]].value }.asTerm,
Inferred(AppliedType(defn.RepeatedParamClass.typeRef, List(TypeRepr.of[t])))
)
case _ => tpe.asType match
case '[t] => '{ $untypedRef.asInstanceOf[t] }.asTerm
}
}

methodOwner.asTerm.select(method).appliedToArgss(accesses).asExpr
methodOwner.asTerm.select(method).appliedToArgs(accesses(args)).asExpr
}


/** Lookup default values for a method's parameters. */
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any]] = {
private def getDefaultParams(using Quotes)(method: quotes.reflect.Symbol): Map[quotes.reflect.Symbol, Expr[Any] => Expr[Any]] = {
// Copy pasted from Cask.
// https://github.com/com-lihaoyi/cask/blob/65b9c8e4fd528feb71575f6e5ef7b5e2e16abbd9/cask/src-3/cask/router/Macros.scala#L38
import quotes.reflect._

val params = method.paramSymss.flatten
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any] => Expr[Any]]

val Name = (method.name + """\$default\$(\d+)""").r
val InitName = """\$lessinit\$greater\$default\$(\d+)""".r
Expand All @@ -159,13 +226,13 @@ object Macros {

idents.foreach{
case deff @ DefDef(Name(idx), _, _, _) =>
val expr = Ref(deff.symbol).asExpr
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

// The `apply` method re-uses the default param factory methods from `<init>`,
// so make sure to check if those exist too
case deff @ DefDef(InitName(idx), _, _, _) if method.name == "apply" =>
val expr = Ref(deff.symbol).asExpr
val expr = (owner: Expr[Any]) => Select(owner.asTerm, deff.symbol).asExpr
defaults += (params(idx.toInt - 1) -> expr)

case _ =>
Expand Down
69 changes: 69 additions & 0 deletions mainargs/test/src/ClassTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,56 @@ object ClassTests extends TestSuite {
@main
case class Qux(moo: String, b: Bar)

case class Cli(@arg(short = 'd') debug: Flag)

@main
class Compat(
@arg(short = 'h') val home: String,
@arg(short = 's') val silent: Flag,
val leftoverArgs: Leftover[String]
) {
override def equals(obj: Any): Boolean =
obj match {
case c: Compat =>
home == c.home && silent == c.silent && leftoverArgs == c.leftoverArgs
case _ => false
}
}
object Compat {
def apply(
home: String = "/home",
silent: Flag = Flag(),
leftoverArgs: Leftover[String] = Leftover()
) = new Compat(home, silent, leftoverArgs)

@deprecated("bin-compat shim", "0.1.0")
private[mainargs] def apply(
home: String,
silent: Flag,
noDefaultPredef: Flag,
leftoverArgs: Leftover[String]
) = new Compat(home, silent, leftoverArgs)
}

implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
implicit val barParser: ParserForClass[Bar] = ParserForClass[Bar]
implicit val quxParser: ParserForClass[Qux] = ParserForClass[Qux]
implicit val cliParser: ParserForClass[Cli] = ParserForClass[Cli]
implicit val compatParser: ParserForClass[Compat] = ParserForClass[Compat]

class PathWrap {
@main
case class Foo(x: Int = 23, y: Int = 47)

object Main {
@main
def run(bar: Bar, bool: Boolean = false) = {
s"${bar.w.value} ${bar.f.x} ${bar.f.y} ${bar.zzzz} $bool"
}
}

implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
}

object Main {
@main
Expand Down Expand Up @@ -161,5 +208,27 @@ object ClassTests extends TestSuite {
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
test("mill-compat") {
test("apply-overload-class") {
compatParser.constructOrThrow(Seq("foo")) ==> Compat(
home = "/home",
silent = Flag(false),
leftoverArgs = Leftover("foo")
)
}
test("no-main-on-class") {
cliParser.constructOrThrow(Seq("-d")) ==> Cli(Flag(true))
}
test("path-dependent-default") {
val p = new PathWrap
p.fooParser.constructOrThrow(Seq()) ==> p.Foo(23, 47)
}
test("path-dependent-default-method") {
val p = new PathWrap
ParserForMethods(p.Main).runOrThrow(
Seq("-x", "1", "-y", "2", "-z", "hello")
) ==> "false 1 2 hello false"
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object VarargsOldTests extends VarargsBaseTests {

@main
def mixedVariadic(@arg(short = 'f') first: Int, args: String*) =
first + args.mkString
first.toString + args.mkString
}

val check = new Checker(ParserForMethods(Base), allowPositional = true)
Expand Down