2

データフレームに読み込まれる csv ファイルがあります。1 つの列の値に基づいて、トレーニング ファイルとテスト ファイルに分割します。

列が「カテゴリ」と呼ばれ、複数回繰り返される cat1、cat2、cat3 などの列値としていくつかのカテゴリ名があるとします。

各カテゴリ名が両方のファイルに少なくとも 1 回含まれるように、ファイルを分割する必要があります。

これまでのところ、比率に基づいてファイルを 2 つに分割できました。私は多くのオプションを試しましたが、これは今のところ最高のものです。

  def executeSplitData(self):
      data = self.readCSV() 
      df = data
      if self.column in data:
         train, test = train_test_split(df, stratify = None, test_size=0.5)
         self.writeTrainFile(train)
         self.writeTestFile(test)

test_train_split の stratify オプションがよくわかりません。助けてください。ありがとう

4

1 に答える 1