chainer 2.0勉強記録
Livedoorブログからの移動
はてなブログからの移動
チュートリアルを一通りやった上で,追加でハマったことのメモ
Dataset
の作成方法
以下のように,DatasetMixin
を使って作る.データの生成(ファイル読み込みなど)をget_example
まで遅延させるのもOK
class Dataset(chainer.dataset.DatasetMixin): def __init__(self): xs = np.array(np.random.uniform(-math.pi, math.pi, (10000, 1)), dtype=np.float32) f = lambda t: np.array([math.sin(t)], dtype=np.float32) ys = np.array([f(e) for e in xs]) self.input = xs self.output = ys def __len__(self): return len(self.output) def get_example(self, i): return self.input[i], self.output[i]
Trainer
に与えるModel
の実装方法
__call__
の引数
第一引数が入力,第二引数が出力(actual
)となり,loss
関数を実装する.
__call__
での引数の扱い
上記Datasetの場合,引数がただの変数なので,Variable(...)
でラップする必要がある.
その他
計算途中を変数に置かないと上手く動かないケースがある(? 未検証)
例えば,
h1 = l1(x)
h2 = l2(h1)
return h2
は上手く動くが,
return l2(l2(x))
は上手く動かないような気がする.